careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the package registry page for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/models/layers.py
CHANGED
|
@@ -3,6 +3,7 @@ Layer module.
|
|
|
3
3
|
|
|
4
4
|
This submodule contains layers used in the CAREamics models.
|
|
5
5
|
"""
|
|
6
|
+
|
|
6
7
|
from typing import List, Optional, Tuple, Union
|
|
7
8
|
|
|
8
9
|
import torch
|
|
@@ -161,6 +162,18 @@ def _unpack_kernel_size(
|
|
|
161
162
|
"""Unpack kernel_size to a tuple of ints.
|
|
162
163
|
|
|
163
164
|
Inspired by Kornia implementation. TODO: link
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
kernel_size : Union[Tuple[int, ...], int]
|
|
169
|
+
Kernel size.
|
|
170
|
+
dim : int
|
|
171
|
+
Number of dimensions.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
Tuple[int, ...]
|
|
176
|
+
Kernel size tuple.
|
|
164
177
|
"""
|
|
165
178
|
if isinstance(kernel_size, int):
|
|
166
179
|
kernel_dims = tuple([kernel_size for _ in range(dim)])
|
|
@@ -172,7 +185,20 @@ def _unpack_kernel_size(
|
|
|
172
185
|
def _compute_zero_padding(
|
|
173
186
|
kernel_size: Union[Tuple[int, ...], int], dim: int
|
|
174
187
|
) -> Tuple[int, ...]:
|
|
175
|
-
"""Utility function that computes zero padding tuple.
|
|
188
|
+
"""Utility function that computes zero padding tuple.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
kernel_size : Union[Tuple[int, ...], int]
|
|
193
|
+
Kernel size.
|
|
194
|
+
dim : int
|
|
195
|
+
Number of dimensions.
|
|
196
|
+
|
|
197
|
+
Returns
|
|
198
|
+
-------
|
|
199
|
+
Tuple[int, ...]
|
|
200
|
+
Zero padding tuple.
|
|
201
|
+
"""
|
|
176
202
|
kernel_dims = _unpack_kernel_size(kernel_size, dim)
|
|
177
203
|
return tuple([(kd - 1) // 2 for kd in kernel_dims])
|
|
178
204
|
|
|
@@ -190,14 +216,19 @@ def get_pascal_kernel_1d(
|
|
|
190
216
|
|
|
191
217
|
Parameters
|
|
192
218
|
----------
|
|
193
|
-
kernel_size:
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
219
|
+
kernel_size : int
|
|
220
|
+
Kernel size.
|
|
221
|
+
norm : bool
|
|
222
|
+
Normalize the kernel, by default False.
|
|
223
|
+
device : Optional[torch.device]
|
|
224
|
+
Device of the tensor, by default None.
|
|
225
|
+
dtype : Optional[torch.dtype]
|
|
226
|
+
Data type of the tensor, by default None.
|
|
197
227
|
|
|
198
228
|
Returns
|
|
199
229
|
-------
|
|
200
|
-
|
|
230
|
+
torch.Tensor
|
|
231
|
+
Pascal kernel.
|
|
201
232
|
|
|
202
233
|
Examples
|
|
203
234
|
--------
|
|
@@ -244,19 +275,28 @@ def _get_pascal_kernel_nd(
|
|
|
244
275
|
) -> torch.Tensor:
|
|
245
276
|
"""Generate pascal filter kernel by kernel size.
|
|
246
277
|
|
|
278
|
+
If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
|
|
279
|
+
otherwise the kernel will be shaped as kernel_size
|
|
280
|
+
|
|
247
281
|
Inspired by Kornia implementation.
|
|
248
282
|
|
|
249
283
|
Parameters
|
|
250
284
|
----------
|
|
251
|
-
kernel_size:
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
285
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
286
|
+
Kernel size for the pascal kernel.
|
|
287
|
+
norm : bool
|
|
288
|
+
Normalize the kernel, by default True.
|
|
289
|
+
dim : int
|
|
290
|
+
Number of dimensions, by default 2.
|
|
291
|
+
device : Optional[torch.device]
|
|
292
|
+
Device of the tensor, by default None.
|
|
293
|
+
dtype : Optional[torch.dtype]
|
|
294
|
+
Data type of the tensor, by default None.
|
|
255
295
|
|
|
256
296
|
Returns
|
|
257
297
|
-------
|
|
258
|
-
|
|
259
|
-
|
|
298
|
+
torch.Tensor
|
|
299
|
+
Pascal kernel.
|
|
260
300
|
|
|
261
301
|
Examples
|
|
262
302
|
--------
|
|
@@ -302,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
|
|
|
302
342
|
"""Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
|
|
303
343
|
|
|
304
344
|
Inspired by Kornia implementation.
|
|
345
|
+
|
|
346
|
+
Parameters
|
|
347
|
+
----------
|
|
348
|
+
x : torch.Tensor
|
|
349
|
+
Input tensor.
|
|
350
|
+
kernel : torch.Tensor
|
|
351
|
+
Kernel tensor.
|
|
352
|
+
stride : int
|
|
353
|
+
Stride.
|
|
354
|
+
max_pool_size : int
|
|
355
|
+
Maximum pool size.
|
|
356
|
+
ceil_mode : bool
|
|
357
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
358
|
+
|
|
359
|
+
Returns
|
|
360
|
+
-------
|
|
361
|
+
torch.Tensor
|
|
362
|
+
Output tensor.
|
|
305
363
|
"""
|
|
306
364
|
# compute local maxima
|
|
307
365
|
x = F.max_pool2d(
|
|
@@ -322,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
|
|
|
322
380
|
"""Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
|
|
323
381
|
|
|
324
382
|
Inspired by Kornia implementation.
|
|
383
|
+
|
|
384
|
+
Parameters
|
|
385
|
+
----------
|
|
386
|
+
x : torch.Tensor
|
|
387
|
+
Input tensor.
|
|
388
|
+
kernel : torch.Tensor
|
|
389
|
+
Kernel tensor.
|
|
390
|
+
stride : int
|
|
391
|
+
Stride.
|
|
392
|
+
max_pool_size : int
|
|
393
|
+
Maximum pool size.
|
|
394
|
+
ceil_mode : bool
|
|
395
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
396
|
+
|
|
397
|
+
Returns
|
|
398
|
+
-------
|
|
399
|
+
torch.Tensor
|
|
400
|
+
Output tensor.
|
|
325
401
|
"""
|
|
326
402
|
# compute local maxima
|
|
327
403
|
x = F.max_pool3d(
|
|
@@ -342,21 +418,16 @@ class MaxBlurPool(nn.Module):
|
|
|
342
418
|
|
|
343
419
|
Parameters
|
|
344
420
|
----------
|
|
345
|
-
dim: int
|
|
346
|
-
Toggles between 2D and 3D
|
|
347
|
-
kernel_size: Union[Tuple[int, int], int]
|
|
421
|
+
dim : int
|
|
422
|
+
Toggles between 2D and 3D.
|
|
423
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
348
424
|
Kernel size for max pooling.
|
|
349
|
-
stride: int
|
|
425
|
+
stride : int
|
|
350
426
|
Stride for pooling.
|
|
351
|
-
max_pool_size: int
|
|
427
|
+
max_pool_size : int
|
|
352
428
|
Max kernel size for max pooling.
|
|
353
|
-
ceil_mode: bool
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
Returns
|
|
357
|
-
-------
|
|
358
|
-
torch.Tensor
|
|
359
|
-
The pooled and blurred tensor.
|
|
429
|
+
ceil_mode : bool
|
|
430
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
360
431
|
"""
|
|
361
432
|
|
|
362
433
|
def __init__(
|
|
@@ -367,6 +438,21 @@ class MaxBlurPool(nn.Module):
|
|
|
367
438
|
max_pool_size: int = 2,
|
|
368
439
|
ceil_mode: bool = False,
|
|
369
440
|
) -> None:
|
|
441
|
+
"""Constructor.
|
|
442
|
+
|
|
443
|
+
Parameters
|
|
444
|
+
----------
|
|
445
|
+
dim : int
|
|
446
|
+
Dimension of the convolution.
|
|
447
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
448
|
+
Kernel size for max pooling.
|
|
449
|
+
stride : int, optional
|
|
450
|
+
Stride, by default 2.
|
|
451
|
+
max_pool_size : int, optional
|
|
452
|
+
Maximum pool size, by default 2.
|
|
453
|
+
ceil_mode : bool, optional
|
|
454
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
455
|
+
"""
|
|
370
456
|
super().__init__()
|
|
371
457
|
self.dim = dim
|
|
372
458
|
self.kernel_size = kernel_size
|
|
@@ -376,7 +462,18 @@ class MaxBlurPool(nn.Module):
|
|
|
376
462
|
self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
|
|
377
463
|
|
|
378
464
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
379
|
-
"""Forward pass of the function.
|
|
465
|
+
"""Forward pass of the function.
|
|
466
|
+
|
|
467
|
+
Parameters
|
|
468
|
+
----------
|
|
469
|
+
x : torch.Tensor
|
|
470
|
+
Input tensor.
|
|
471
|
+
|
|
472
|
+
Returns
|
|
473
|
+
-------
|
|
474
|
+
torch.Tensor
|
|
475
|
+
Output tensor.
|
|
476
|
+
"""
|
|
380
477
|
self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
|
|
381
478
|
if self.dim == 2:
|
|
382
479
|
return _max_blur_pool_by_kernel2d(
|
|
@@ -3,6 +3,7 @@ Model factory.
|
|
|
3
3
|
|
|
4
4
|
Model creation factory functions.
|
|
5
5
|
"""
|
|
6
|
+
|
|
6
7
|
from typing import Union
|
|
7
8
|
|
|
8
9
|
import torch
|
|
@@ -26,7 +27,7 @@ def model_factory(
|
|
|
26
27
|
Parameters
|
|
27
28
|
----------
|
|
28
29
|
model_configuration : Union[UNetModel, VAEModel]
|
|
29
|
-
Model configuration
|
|
30
|
+
Model configuration.
|
|
30
31
|
|
|
31
32
|
Returns
|
|
32
33
|
-------
|
careamics/models/unet.py
CHANGED
|
@@ -3,7 +3,8 @@ UNet model.
|
|
|
3
3
|
|
|
4
4
|
A UNet encoder, decoder and complete model.
|
|
5
5
|
"""
|
|
6
|
-
|
|
6
|
+
|
|
7
|
+
from typing import Any, List, Tuple, Union
|
|
7
8
|
|
|
8
9
|
import torch
|
|
9
10
|
import torch.nn as nn
|
|
@@ -33,6 +34,11 @@ class UnetEncoder(nn.Module):
|
|
|
33
34
|
Dropout probability, by default 0.0.
|
|
34
35
|
pool_kernel : int, optional
|
|
35
36
|
Kernel size for the max pooling layers, by default 2.
|
|
37
|
+
n2v2 : bool, optional
|
|
38
|
+
Whether to use N2V2 architecture, by default False.
|
|
39
|
+
groups : int, optional
|
|
40
|
+
Number of blocked connections from input channels to output
|
|
41
|
+
channels, by default 1.
|
|
36
42
|
"""
|
|
37
43
|
|
|
38
44
|
def __init__(
|
|
@@ -45,6 +51,7 @@ class UnetEncoder(nn.Module):
|
|
|
45
51
|
dropout: float = 0.0,
|
|
46
52
|
pool_kernel: int = 2,
|
|
47
53
|
n2v2: bool = False,
|
|
54
|
+
groups: int = 1,
|
|
48
55
|
) -> None:
|
|
49
56
|
"""
|
|
50
57
|
Constructor.
|
|
@@ -65,6 +72,11 @@ class UnetEncoder(nn.Module):
|
|
|
65
72
|
Dropout probability, by default 0.0.
|
|
66
73
|
pool_kernel : int, optional
|
|
67
74
|
Kernel size for the max pooling layers, by default 2.
|
|
75
|
+
n2v2 : bool, optional
|
|
76
|
+
Whether to use N2V2 architecture, by default False.
|
|
77
|
+
groups : int, optional
|
|
78
|
+
Number of blocked connections from input channels to output
|
|
79
|
+
channels, by default 1.
|
|
68
80
|
"""
|
|
69
81
|
super().__init__()
|
|
70
82
|
|
|
@@ -77,7 +89,7 @@ class UnetEncoder(nn.Module):
|
|
|
77
89
|
encoder_blocks = []
|
|
78
90
|
|
|
79
91
|
for n in range(depth):
|
|
80
|
-
out_channels = num_channels_init * (2**n)
|
|
92
|
+
out_channels = num_channels_init * (2**n) * groups
|
|
81
93
|
in_channels = in_channels if n == 0 else out_channels // 2
|
|
82
94
|
encoder_blocks.append(
|
|
83
95
|
Conv_Block(
|
|
@@ -86,6 +98,7 @@ class UnetEncoder(nn.Module):
|
|
|
86
98
|
out_channels=out_channels,
|
|
87
99
|
dropout_perc=dropout,
|
|
88
100
|
use_batch_norm=use_batch_norm,
|
|
101
|
+
groups=groups,
|
|
89
102
|
)
|
|
90
103
|
)
|
|
91
104
|
encoder_blocks.append(self.pooling)
|
|
@@ -131,6 +144,11 @@ class UnetDecoder(nn.Module):
|
|
|
131
144
|
Whether to use batch normalization, by default True.
|
|
132
145
|
dropout : float, optional
|
|
133
146
|
Dropout probability, by default 0.0.
|
|
147
|
+
n2v2 : bool, optional
|
|
148
|
+
Whether to use N2V2 architecture, by default False.
|
|
149
|
+
groups : int, optional
|
|
150
|
+
Number of blocked connections from input channels to output
|
|
151
|
+
channels, by default 1.
|
|
134
152
|
"""
|
|
135
153
|
|
|
136
154
|
def __init__(
|
|
@@ -141,6 +159,7 @@ class UnetDecoder(nn.Module):
|
|
|
141
159
|
use_batch_norm: bool = True,
|
|
142
160
|
dropout: float = 0.0,
|
|
143
161
|
n2v2: bool = False,
|
|
162
|
+
groups: int = 1,
|
|
144
163
|
) -> None:
|
|
145
164
|
"""
|
|
146
165
|
Constructor.
|
|
@@ -157,15 +176,21 @@ class UnetDecoder(nn.Module):
|
|
|
157
176
|
Whether to use batch normalization, by default True.
|
|
158
177
|
dropout : float, optional
|
|
159
178
|
Dropout probability, by default 0.0.
|
|
179
|
+
n2v2 : bool, optional
|
|
180
|
+
Whether to use N2V2 architecture, by default False.
|
|
181
|
+
groups : int, optional
|
|
182
|
+
Number of blocked connections from input channels to output
|
|
183
|
+
channels, by default 1.
|
|
160
184
|
"""
|
|
161
185
|
super().__init__()
|
|
162
186
|
|
|
163
187
|
upsampling = nn.Upsample(
|
|
164
188
|
scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
|
|
165
189
|
)
|
|
166
|
-
in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
|
|
190
|
+
in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
|
|
167
191
|
|
|
168
192
|
self.n2v2 = n2v2
|
|
193
|
+
self.groups = groups
|
|
169
194
|
|
|
170
195
|
self.bottleneck = Conv_Block(
|
|
171
196
|
conv_dim,
|
|
@@ -174,34 +199,32 @@ class UnetDecoder(nn.Module):
|
|
|
174
199
|
intermediate_channel_multiplier=2,
|
|
175
200
|
use_batch_norm=use_batch_norm,
|
|
176
201
|
dropout_perc=dropout,
|
|
202
|
+
groups=self.groups,
|
|
177
203
|
)
|
|
178
204
|
|
|
179
|
-
decoder_blocks = []
|
|
205
|
+
decoder_blocks: List[nn.Module] = []
|
|
180
206
|
for n in range(depth):
|
|
181
207
|
decoder_blocks.append(upsampling)
|
|
182
|
-
in_channels = (
|
|
183
|
-
num_channels_init ** (depth - n)
|
|
184
|
-
if (self.n2v2 and n == depth - 1)
|
|
185
|
-
else num_channels_init * 2 ** (depth - n)
|
|
186
|
-
)
|
|
208
|
+
in_channels = (num_channels_init * 2 ** (depth - n)) * groups
|
|
187
209
|
out_channels = in_channels // 2
|
|
188
210
|
decoder_blocks.append(
|
|
189
211
|
Conv_Block(
|
|
190
212
|
conv_dim,
|
|
191
|
-
in_channels=
|
|
192
|
-
|
|
193
|
-
|
|
213
|
+
in_channels=(
|
|
214
|
+
in_channels + in_channels // 2 if n > 0 else in_channels
|
|
215
|
+
),
|
|
194
216
|
out_channels=out_channels,
|
|
195
217
|
intermediate_channel_multiplier=2,
|
|
196
218
|
dropout_perc=dropout,
|
|
197
219
|
activation="ReLU",
|
|
198
220
|
use_batch_norm=use_batch_norm,
|
|
221
|
+
groups=groups,
|
|
199
222
|
)
|
|
200
223
|
)
|
|
201
224
|
|
|
202
225
|
self.decoder_blocks = nn.ModuleList(decoder_blocks)
|
|
203
226
|
|
|
204
|
-
def forward(self, *features:
|
|
227
|
+
def forward(self, *features: torch.Tensor) -> torch.Tensor:
|
|
205
228
|
"""
|
|
206
229
|
Forward pass.
|
|
207
230
|
|
|
@@ -217,20 +240,73 @@ class UnetDecoder(nn.Module):
|
|
|
217
240
|
Output of the decoder.
|
|
218
241
|
"""
|
|
219
242
|
x: torch.Tensor = features[0]
|
|
220
|
-
skip_connections: torch.Tensor = features[1:
|
|
243
|
+
skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
221
244
|
|
|
222
245
|
x = self.bottleneck(x)
|
|
223
246
|
|
|
224
247
|
for i, module in enumerate(self.decoder_blocks):
|
|
225
248
|
x = module(x)
|
|
226
249
|
if isinstance(module, nn.Upsample):
|
|
250
|
+
# divide index by 2 because of upsampling layers
|
|
251
|
+
skip_connection: torch.Tensor = skip_connections[i // 2]
|
|
227
252
|
if self.n2v2:
|
|
228
253
|
if x.shape != skip_connections[-1].shape:
|
|
229
|
-
x =
|
|
254
|
+
x = self._interleave(x, skip_connection, self.groups)
|
|
230
255
|
else:
|
|
231
|
-
x =
|
|
256
|
+
x = self._interleave(x, skip_connection, self.groups)
|
|
232
257
|
return x
|
|
233
258
|
|
|
259
|
+
@staticmethod
|
|
260
|
+
def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
|
|
261
|
+
"""Interleave two tensors.
|
|
262
|
+
|
|
263
|
+
Splits the tensors `A` and `B` into equally sized groups along the channel
|
|
264
|
+
axis (axis=1); then concatenates the groups in alternating order along the
|
|
265
|
+
channel axis, starting with the first group from tensor A.
|
|
266
|
+
|
|
267
|
+
Parameters
|
|
268
|
+
----------
|
|
269
|
+
A : torch.Tensor
|
|
270
|
+
First tensor.
|
|
271
|
+
B : torch.Tensor
|
|
272
|
+
Second tensor.
|
|
273
|
+
groups : int
|
|
274
|
+
The number of groups.
|
|
275
|
+
|
|
276
|
+
Returns
|
|
277
|
+
-------
|
|
278
|
+
torch.Tensor
|
|
279
|
+
Interleaved tensor.
|
|
280
|
+
|
|
281
|
+
Raises
|
|
282
|
+
------
|
|
283
|
+
ValueError:
|
|
284
|
+
If either of `A` or `B`'s channel axis is not divisible by `groups`.
|
|
285
|
+
"""
|
|
286
|
+
if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
|
|
287
|
+
raise ValueError(f"Number of channels not divisible by {groups} groups.")
|
|
288
|
+
|
|
289
|
+
m = A.shape[1] // groups
|
|
290
|
+
n = B.shape[1] // groups
|
|
291
|
+
|
|
292
|
+
A_groups: List[torch.Tensor] = [
|
|
293
|
+
A[:, i * m : (i + 1) * m] for i in range(groups)
|
|
294
|
+
]
|
|
295
|
+
B_groups: List[torch.Tensor] = [
|
|
296
|
+
B[:, i * n : (i + 1) * n] for i in range(groups)
|
|
297
|
+
]
|
|
298
|
+
|
|
299
|
+
interleaved = torch.cat(
|
|
300
|
+
[
|
|
301
|
+
tensor_list[i]
|
|
302
|
+
for i in range(groups)
|
|
303
|
+
for tensor_list in [A_groups, B_groups]
|
|
304
|
+
],
|
|
305
|
+
dim=1,
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
return interleaved
|
|
309
|
+
|
|
234
310
|
|
|
235
311
|
class UNet(nn.Module):
|
|
236
312
|
"""
|
|
@@ -257,8 +333,14 @@ class UNet(nn.Module):
|
|
|
257
333
|
Dropout probability, by default 0.0.
|
|
258
334
|
pool_kernel : int, optional
|
|
259
335
|
Kernel size of the pooling layers, by default 2.
|
|
260
|
-
|
|
336
|
+
final_activation : Optional[Callable], optional
|
|
261
337
|
Activation function to use for the last layer, by default None.
|
|
338
|
+
n2v2 : bool, optional
|
|
339
|
+
Whether to use N2V2 architecture, by default False.
|
|
340
|
+
independent_channels : bool
|
|
341
|
+
Whether to train the channels independently, by default True.
|
|
342
|
+
**kwargs : Any
|
|
343
|
+
Additional keyword arguments, unused.
|
|
262
344
|
"""
|
|
263
345
|
|
|
264
346
|
def __init__(
|
|
@@ -273,6 +355,7 @@ class UNet(nn.Module):
|
|
|
273
355
|
pool_kernel: int = 2,
|
|
274
356
|
final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
|
|
275
357
|
n2v2: bool = False,
|
|
358
|
+
independent_channels: bool = True,
|
|
276
359
|
**kwargs: Any,
|
|
277
360
|
) -> None:
|
|
278
361
|
"""
|
|
@@ -296,11 +379,20 @@ class UNet(nn.Module):
|
|
|
296
379
|
Dropout probability, by default 0.0.
|
|
297
380
|
pool_kernel : int, optional
|
|
298
381
|
Kernel size of the pooling layers, by default 2.
|
|
299
|
-
|
|
382
|
+
final_activation : Optional[Callable], optional
|
|
300
383
|
Activation function to use for the last layer, by default None.
|
|
384
|
+
n2v2 : bool, optional
|
|
385
|
+
Whether to use N2V2 architecture, by default False.
|
|
386
|
+
independent_channels : bool
|
|
387
|
+
Whether to train parallel independent networks for each channel, by
|
|
388
|
+
default True.
|
|
389
|
+
**kwargs : Any
|
|
390
|
+
Additional keyword arguments, unused.
|
|
301
391
|
"""
|
|
302
392
|
super().__init__()
|
|
303
393
|
|
|
394
|
+
groups = in_channels if independent_channels else 1
|
|
395
|
+
|
|
304
396
|
self.encoder = UnetEncoder(
|
|
305
397
|
conv_dims,
|
|
306
398
|
in_channels=in_channels,
|
|
@@ -310,6 +402,7 @@ class UNet(nn.Module):
|
|
|
310
402
|
dropout=dropout,
|
|
311
403
|
pool_kernel=pool_kernel,
|
|
312
404
|
n2v2=n2v2,
|
|
405
|
+
groups=groups,
|
|
313
406
|
)
|
|
314
407
|
|
|
315
408
|
self.decoder = UnetDecoder(
|
|
@@ -319,11 +412,13 @@ class UNet(nn.Module):
|
|
|
319
412
|
use_batch_norm=use_batch_norm,
|
|
320
413
|
dropout=dropout,
|
|
321
414
|
n2v2=n2v2,
|
|
415
|
+
groups=groups,
|
|
322
416
|
)
|
|
323
417
|
self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
|
|
324
|
-
in_channels=num_channels_init,
|
|
418
|
+
in_channels=num_channels_init * groups,
|
|
325
419
|
out_channels=num_classes,
|
|
326
420
|
kernel_size=1,
|
|
421
|
+
groups=groups,
|
|
327
422
|
)
|
|
328
423
|
self.final_activation = get_activation(final_activation)
|
|
329
424
|
|
|
@@ -1,8 +1,5 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Prediction convenience functions.
|
|
1
|
+
"""Prediction utility functions."""
|
|
3
2
|
|
|
4
|
-
These functions are used during prediction.
|
|
5
|
-
"""
|
|
6
3
|
from typing import List
|
|
7
4
|
|
|
8
5
|
import numpy as np
|
|
@@ -20,7 +17,7 @@ def stitch_prediction(
|
|
|
20
17
|
----------
|
|
21
18
|
tiles : List[torch.Tensor]
|
|
22
19
|
Cropped tiles and their respective stitching coordinates.
|
|
23
|
-
|
|
20
|
+
stitching_data : List
|
|
24
21
|
List of information and coordinates obtained from
|
|
25
22
|
`dataset.tiled_patching.extract_tiles`.
|
|
26
23
|
|
careamics/transforms/__init__.py
CHANGED
|
@@ -3,39 +3,18 @@
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"get_all_transforms",
|
|
5
5
|
"N2VManipulate",
|
|
6
|
-
"NDFlip",
|
|
6
|
+
"XYFlip",
|
|
7
7
|
"XYRandomRotate90",
|
|
8
8
|
"ImageRestorationTTA",
|
|
9
9
|
"Denormalize",
|
|
10
10
|
"Normalize",
|
|
11
|
+
"Compose",
|
|
11
12
|
]
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
from .compose import Compose, get_all_transforms
|
|
14
16
|
from .n2v_manipulate import N2VManipulate
|
|
15
|
-
from .nd_flip import NDFlip
|
|
16
17
|
from .normalize import Denormalize, Normalize
|
|
17
18
|
from .tta import ImageRestorationTTA
|
|
19
|
+
from .xy_flip import XYFlip
|
|
18
20
|
from .xy_random_rotate90 import XYRandomRotate90
|
|
19
|
-
|
|
20
|
-
ALL_TRANSFORMS = {
|
|
21
|
-
"Normalize": Normalize,
|
|
22
|
-
"N2VManipulate": N2VManipulate,
|
|
23
|
-
"NDFlip": NDFlip,
|
|
24
|
-
"XYRandomRotate90": XYRandomRotate90,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def get_all_transforms() -> dict:
|
|
29
|
-
"""Return all the transforms accepted by CAREamics.
|
|
30
|
-
|
|
31
|
-
Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
|
|
32
|
-
https://albumentations.ai/), only a few transformations are explicitely supported
|
|
33
|
-
(see `SupportedTransform`).
|
|
34
|
-
|
|
35
|
-
Returns
|
|
36
|
-
-------
|
|
37
|
-
dict
|
|
38
|
-
A dictionary with all the transforms accepted by CAREamics, where the keys are
|
|
39
|
-
the transform names and the values are the transform classes.
|
|
40
|
-
"""
|
|
41
|
-
return ALL_TRANSFORMS
|