careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +321 -131
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -13
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -202
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -13
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +89 -56
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -271
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -296
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -115
- careamics/dataset/patching.py +0 -493
- careamics/dataset/prepare_dataset.py +0 -174
- careamics/dataset/tiff_dataset.py +0 -211
- careamics/engine.py +0 -954
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -102
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -156
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc1.dist-info/METADATA +0 -80
- careamics-0.1.0rc1.dist-info/RECORD +0 -46
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py
CHANGED
@@ -3,12 +3,14 @@ UNet model.
 
 A UNet encoder, decoder and complete model.
 """
-from typing import
+from typing import Any, List, Union
 
 import torch
 import torch.nn as nn
 
-from .
+from ..config.support import SupportedActivation
+from .activation import get_activation
+from .layers import Conv_Block, MaxBlurPool
 
 
 class UnetEncoder(nn.Module):
@@ -42,6 +44,7 @@ class UnetEncoder(nn.Module):
         use_batch_norm: bool = True,
         dropout: float = 0.0,
         pool_kernel: int = 2,
+        n2v2: bool = False,
     ) -> None:
         """
         Constructor.
@@ -65,7 +68,11 @@ class UnetEncoder(nn.Module):
         """
         super().__init__()
 
-        self.pooling =
+        self.pooling = (
+            getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
+            if not n2v2
+            else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel)
+        )
 
         encoder_blocks = []
 
@@ -82,7 +89,6 @@ class UnetEncoder(nn.Module):
                 )
             )
             encoder_blocks.append(self.pooling)
-
         self.encoder_blocks = nn.ModuleList(encoder_blocks)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
@@ -134,6 +140,7 @@ class UnetDecoder(nn.Module):
         num_channels_init: int = 64,
         use_batch_norm: bool = True,
         dropout: float = 0.0,
+        n2v2: bool = False,
     ) -> None:
         """
         Constructor.
@@ -157,6 +164,9 @@ class UnetDecoder(nn.Module):
             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
         )
         in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
+
+        self.n2v2 = n2v2
+
         self.bottleneck = Conv_Block(
             conv_dim,
             in_channels=in_channels,
@@ -169,12 +179,18 @@ class UnetDecoder(nn.Module):
         decoder_blocks = []
         for n in range(depth):
            decoder_blocks.append(upsampling)
-            in_channels =
-
+            in_channels = (
+                num_channels_init ** (depth - n)
+                if (self.n2v2 and n == depth - 1)
+                else num_channels_init * 2 ** (depth - n)
+            )
+            out_channels = in_channels // 2
            decoder_blocks.append(
                Conv_Block(
                    conv_dim,
-                    in_channels=in_channels
+                    in_channels=in_channels + in_channels // 2
+                    if n > 0
+                    else in_channels,
                    out_channels=out_channels,
                    intermediate_channel_multiplier=2,
                    dropout_perc=dropout,
@@ -200,13 +216,19 @@ class UnetDecoder(nn.Module):
         torch.Tensor
             Output of the decoder.
         """
-        x = features[0]
-        skip_connections = features[1:][::-1]
+        x: torch.Tensor = features[0]
+        skip_connections: torch.Tensor = features[1:][::-1]
+
         x = self.bottleneck(x)
+
         for i, module in enumerate(self.decoder_blocks):
             x = module(x)
             if isinstance(module, nn.Upsample):
-
+                if self.n2v2:
+                    if x.shape != skip_connections[-1].shape:
+                        x = torch.cat([x, skip_connections[i // 2]], axis=1)
+                else:
+                    x = torch.cat([x, skip_connections[i // 2]], axis=1)
         return x
 
 
@@ -214,12 +236,12 @@ class UNet(nn.Module):
     """
     UNet model.
 
-    Adapted for PyTorch from
+    Adapted for PyTorch from:
    https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
 
     Parameters
     ----------
-
+    conv_dims : int
         Number of dimensions of the convolution layers (2 or 3).
     num_classes : int, optional
         Number of classes to predict, by default 1.
@@ -241,7 +263,7 @@ class UNet(nn.Module):
 
     def __init__(
         self,
-
+        conv_dims: int,
         num_classes: int = 1,
         in_channels: int = 1,
         depth: int = 3,
@@ -249,14 +271,16 @@ class UNet(nn.Module):
         use_batch_norm: bool = True,
         dropout: float = 0.0,
         pool_kernel: int = 2,
-
+        final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
+        n2v2: bool = False,
+        **kwargs: Any,
     ) -> None:
         """
         Constructor.
 
         Parameters
         ----------
-
+        conv_dims : int
             Number of dimensions of the convolution layers (2 or 3).
         num_classes : int, optional
             Number of classes to predict, by default 1.
@@ -278,28 +302,30 @@ class UNet(nn.Module):
         super().__init__()
 
         self.encoder = UnetEncoder(
-
+            conv_dims,
             in_channels=in_channels,
             depth=depth,
             num_channels_init=num_channels_init,
             use_batch_norm=use_batch_norm,
             dropout=dropout,
             pool_kernel=pool_kernel,
+            n2v2=n2v2,
         )
 
         self.decoder = UnetDecoder(
-
+            conv_dims,
             depth=depth,
             num_channels_init=num_channels_init,
             use_batch_norm=use_batch_norm,
             dropout=dropout,
+            n2v2=n2v2,
         )
-        self.final_conv = getattr(nn, f"Conv{
+        self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
             in_channels=num_channels_init,
             out_channels=num_classes,
             kernel_size=1,
         )
-        self.
+        self.final_activation = get_activation(final_activation)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -318,5 +344,5 @@ class UNet(nn.Module):
         encoder_features = self.encoder(x)
         x = self.decoder(*encoder_features)
         x = self.final_conv(x)
-        x = self.
+        x = self.final_activation(x)
         return x
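For orientation, here is a minimal sketch (not part of the diff) of how the reworked constructor might be exercised. It assumes careamics 0.1.0rc3 is installed and relies only on the signature visible above (conv_dims, n2v2, final_activation):

import torch
from careamics.models.unet import UNet

# 2D N2V2 variant: blur pooling in the encoder, modified skip connections.
model = UNet(conv_dims=2, in_channels=1, depth=3, n2v2=True)
x = torch.rand(8, 1, 64, 64)  # (batch, channels, Y, X)
y = model(x)  # final_activation defaults to "None", i.e. identity
print(y.shape)  # expected to match the input shape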
careamics/prediction/stitch_prediction.py
ADDED
@@ -0,0 +1,73 @@
+"""
+Prediction convenience functions.
+
+These functions are used during prediction.
+"""
+from typing import List
+
+import numpy as np
+import torch
+
+
+def stitch_prediction(
+    tiles: List[torch.Tensor],
+    stitching_data: List[List[torch.Tensor]],
+) -> torch.Tensor:
+    """
+    Stitch tiles back together to form a full image.
+
+    Parameters
+    ----------
+    tiles : List[torch.Tensor]
+        Cropped tiles and their respective stitching coordinates.
+    stitching_coords : List
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    np.ndarray
+        Full image.
+    """
+    # retrieve whole array size; there are two cases to consider:
+    # 1. the tiles are stored in a list
+    # 2. the tiles are stored in a list with batches along the first dim
+    if tiles[0].shape[0] > 1:
+        input_shape = np.array(
+            [el.numpy() for el in stitching_data[0][0][0]], dtype=int
+        ).squeeze()
+    else:
+        input_shape = np.array(
+            [el.numpy() for el in stitching_data[0][0]], dtype=int
+        ).squeeze()
+
+    # TODO should use torch.zeros instead of np.zeros
+    predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32))
+
+    for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
+        tiles, stitching_data
+    ):
+        for batch_idx in range(tile_batch.shape[0]):
+            # Compute coordinates for cropping predicted tile
+            slices = tuple(
+                [
+                    slice(c[0][batch_idx], c[1][batch_idx])
+                    for c in overlap_crop_coords_batch
+                ]
+            )
+
+            # Crop predicted tile according to overlap coordinates
+            cropped_tile = tile_batch[batch_idx].squeeze()[slices]
+
+            # Insert cropped tile into predicted image using stitch coordinates
+            predicted_image[
+                (
+                    ...,
+                    *[
+                        slice(c[0][batch_idx], c[1][batch_idx])
+                        for c in stitch_coords_batch
+                    ],
+                )
+            ] = cropped_tile.to(torch.float32)
+
+    return predicted_image
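A toy invocation (an assumption, not from the diff) illustrating the nested structure stitch_prediction expects: per tile, a (shape, overlap-crop coordinates, stitch coordinates) triple, with one tensor value per batch element. Here a single full-size tile is written back into a 4x4 image:

import torch
from careamics.prediction.stitch_prediction import stitch_prediction

tile = torch.rand(1, 1, 4, 4)  # one tile, batch size 1
shape = [torch.tensor([4]), torch.tensor([4])]  # target image shape (Y, X)
coords = [
    (torch.tensor([0]), torch.tensor([4])),  # Y bounds, one value per batch element
    (torch.tensor([0]), torch.tensor([4])),  # X bounds
]

# (input shape, overlap crop coordinates, stitch coordinates) for the tile
stitched = stitch_prediction([tile], [(shape, coords, coords)])
print(stitched.shape)  # torch.Size([4, 4])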
careamics/transforms/__init__.py
ADDED

@@ -0,0 +1,41 @@
+"""Transforms that are used to augment the data."""
+
+__all__ = [
+    "get_all_transforms",
+    "N2VManipulate",
+    "NDFlip",
+    "XYRandomRotate90",
+    "ImageRestorationTTA",
+    "Denormalize",
+    "Normalize",
+]
+
+
+from .n2v_manipulate import N2VManipulate
+from .nd_flip import NDFlip
+from .normalize import Denormalize, Normalize
+from .tta import ImageRestorationTTA
+from .xy_random_rotate90 import XYRandomRotate90
+
+ALL_TRANSFORMS = {
+    "Normalize": Normalize,
+    "N2VManipulate": N2VManipulate,
+    "NDFlip": NDFlip,
+    "XYRandomRotate90": XYRandomRotate90,
+}
+
+
+def get_all_transforms() -> dict:
+    """Return all the transforms accepted by CAREamics.
+
+    Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
+    https://albumentations.ai/), only a few transformations are explicitly supported
+    (see `SupportedTransform`).
+
+    Returns
+    -------
+    dict
+        A dictionary with all the transforms accepted by CAREamics, where the keys are
+        the transform names and the values are the transform classes.
+    """
+    return ALL_TRANSFORMS
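As a quick illustration (not part of the diff), the registry lets callers resolve a transform class by name:

from careamics.transforms import get_all_transforms

# Resolve a transform class by its registry key and instantiate it.
flip_cls = get_all_transforms()["NDFlip"]
flip = flip_cls(p=0.5, is_3D=False)  # signature shown in nd_flip.py below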
careamics/transforms/n2v_manipulate.py
ADDED

@@ -0,0 +1,113 @@
+from typing import Any, Literal, Optional, Tuple
+
+import numpy as np
+from albumentations import ImageOnlyTransform
+
+from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
+
+from .pixel_manipulation import median_manipulate, uniform_manipulate
+from .struct_mask_parameters import StructMaskParameters
+
+
+class N2VManipulate(ImageOnlyTransform):
+    """
+    Default augmentation for the N2V model.
+
+    This transform expects (Z)YXC dimensions.
+
+    Parameters
+    ----------
+    mask_pixel_percentage : float
+        Approximate percentage of pixels to be masked.
+    roi_size : int
+        Size of the ROI the new pixel value is sampled from, by default 11.
+    """
+
+    def __init__(
+        self,
+        roi_size: int = 11,
+        masked_pixel_percentage: float = 0.2,
+        strategy: Literal[
+            "uniform", "median"
+        ] = SupportedPixelManipulation.UNIFORM.value,
+        remove_center: bool = True,
+        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
+        struct_mask_span: int = 5,
+    ):
+        """Constructor.
+
+        Parameters
+        ----------
+        roi_size : int, optional
+            Size of the replacement area, by default 11.
+        masked_pixel_percentage : float, optional
+            Percentage of pixels to mask, by default 0.2.
+        strategy : Literal["uniform", "median"], optional
+            Replacement strategy, uniform or median, by default uniform.
+        remove_center : bool, optional
+            Whether to remove the central pixel from the patch, by default True.
+        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
+            StructN2V mask axis, by default "none".
+        struct_mask_span : int, optional
+            StructN2V mask span, by default 5.
+        """
+        super().__init__(p=1)
+        self.masked_pixel_percentage = masked_pixel_percentage
+        self.roi_size = roi_size
+        self.strategy = strategy
+        self.remove_center = remove_center
+
+        if struct_mask_axis == SupportedStructAxis.NONE:
+            self.struct_mask: Optional[StructMaskParameters] = None
+        else:
+            self.struct_mask = StructMaskParameters(
+                axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
+                span=struct_mask_span,
+            )
+
+    def apply(
+        self, patch: np.ndarray, **kwargs: Any
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Apply the transform to the image.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+        """
+        masked = np.zeros_like(patch)
+        mask = np.zeros_like(patch)
+        if self.strategy == SupportedPixelManipulation.UNIFORM:
+            # Iterate over the channels to apply manipulation separately
+            for c in range(patch.shape[-1]):
+                masked[..., c], mask[..., c] = uniform_manipulate(
+                    patch=patch[..., c],
+                    mask_pixel_percentage=self.masked_pixel_percentage,
+                    subpatch_size=self.roi_size,
+                    remove_center=self.remove_center,
+                    struct_params=self.struct_mask,
+                )
+        elif self.strategy == SupportedPixelManipulation.MEDIAN:
+            # Iterate over the channels to apply manipulation separately
+            for c in range(patch.shape[-1]):
+                masked[..., c], mask[..., c] = median_manipulate(
+                    patch=patch[..., c],
+                    mask_pixel_percentage=self.masked_pixel_percentage,
+                    subpatch_size=self.roi_size,
+                    struct_params=self.struct_mask,
+                )
+        else:
+            raise ValueError(f"Unknown masking strategy ({self.strategy}).")
+
+        # TODO why return patch?
+        return masked, patch, mask
+
+    def get_transform_init_args_names(self) -> Tuple[str, ...]:
+        """Get the transform parameters.
+
+        Returns
+        -------
+        Tuple[str, ...]
+            Transform parameters.
+        """
+        return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")
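A minimal usage sketch (an assumption, not from the diff), calling apply directly on a (Y, X, C) patch rather than going through an Albumentations pipeline:

import numpy as np
from careamics.transforms import N2VManipulate

patch = np.random.rand(64, 64, 1).astype(np.float32)  # (Y, X, C)
manip = N2VManipulate(roi_size=11, masked_pixel_percentage=0.2)

# apply returns (masked_patch, original_patch, mask), as in the code above.
masked, original, mask = manip.apply(patch)
print(int(mask.sum()), "pixels masked (assuming a binary mask)")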
careamics/transforms/nd_flip.py
ADDED

@@ -0,0 +1,93 @@
+from typing import Any, Dict, Tuple
+
+import numpy as np
+from albumentations import DualTransform
+
+
+class NDFlip(DualTransform):
+    """Flip ND arrays on a single axis.
+
+    This transform ignores singleton axes and randomly flips one of the other
+    axes, with the exception of the first and last axes (sample and channels).
+
+    This transform expects (Z)YXC dimensions.
+    """
+
+    def __init__(self, p: float = 0.5, is_3D: bool = False, flip_z: bool = True):
+        """Constructor.
+
+        Parameters
+        ----------
+        p : float, optional
+            Probability to apply the transform, by default 0.5.
+        is_3D : bool, optional
+            Whether the data is 3D, by default False.
+        flip_z : bool, optional
+            Whether to flip the Z dimension, by default True.
+        """
+        super().__init__(p=p)
+
+        self.is_3D = is_3D
+        self.flip_z = flip_z
+
+        # "flippable" axes
+        if is_3D:
+            self.axis_indices = [0, 1, 2] if flip_z else [1, 2]
+        else:
+            self.axis_indices = [0, 1]
+
+    def get_params(self, **kwargs: Any) -> Dict[str, int]:
+        """Get the transform parameters.
+
+        Returns
+        -------
+        Dict[str, int]
+            Transform parameters.
+        """
+        return {"flip_axis": np.random.choice(self.axis_indices)}
+
+    def apply(self, patch: np.ndarray, flip_axis: int, **kwargs: Any) -> np.ndarray:
+        """Apply the transform to the image.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+        flip_axis : int
+            Axis along which to flip the patch.
+        """
+        if len(patch.shape) == 3 and self.is_3D:
+            raise ValueError(
+                "Incompatible patch shape and dimensionality. ZYXC patch shape "
+                "expected, but got YXC shape."
+            )
+
+        return np.ascontiguousarray(np.flip(patch, axis=flip_axis))
+
+    def apply_to_mask(
+        self, mask: np.ndarray, flip_axis: int, **kwargs: Any
+    ) -> np.ndarray:
+        """Apply the transform to the mask.
+
+        Parameters
+        ----------
+        mask : np.ndarray
+            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+        """
+        if len(mask.shape) == 3 and self.is_3D:
+            raise ValueError(
+                "Incompatible mask shape and dimensionality. ZYXC patch shape "
+                "expected, but got YXC shape."
+            )
+
+        return np.ascontiguousarray(np.flip(mask, axis=flip_axis))
+
+    def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]:
+        """Get the transform arguments names.
+
+        Returns
+        -------
+        Tuple[str, ...]
+            Transform arguments names.
+        """
+        return ("is_3D", "flip_z")
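A short sketch (not from the diff) showing the deterministic path through apply; get_params would normally draw flip_axis from the "flippable" axes at random:

import numpy as np
from careamics.transforms import NDFlip

patch = np.arange(32, dtype=np.float32).reshape(4, 8, 1)  # (Y, X, C)
flip = NDFlip(p=1.0, is_3D=False)

flipped = flip.apply(patch, flip_axis=1)  # flip along X
assert np.array_equal(flipped, patch[:, ::-1, :])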
careamics/transforms/normalize.py
ADDED

@@ -0,0 +1,109 @@
+from typing import Any
+
+import numpy as np
+from albumentations import DualTransform
+
+
+class Normalize(DualTransform):
+    """
+    Normalize an image or image patch.
+
+    Normalization is zero mean and unit variance. This transform expects (Z)YXC
+    dimensions.
+
+    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
+    division by zero and that it returns a float32 image.
+
+    Attributes
+    ----------
+    mean : float
+        Mean value.
+    std : float
+        Standard deviation value.
+    eps : float
+        Epsilon value to avoid division by zero.
+    """
+
+    def __init__(
+        self,
+        mean: float,
+        std: float,
+    ):
+        super().__init__(always_apply=True, p=1)
+
+        self.mean = mean
+        self.std = std
+        self.eps = 1e-6
+
+    def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
+        """
+        Apply the transform to the image.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+
+        Returns
+        -------
+        np.ndarray
+            Normalized image or image patch.
+        """
+        return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
+
+    def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray:
+        """
+        Apply the transform to the mask.
+
+        The mask is returned as is.
+
+        Parameters
+        ----------
+        mask : np.ndarray
+            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+        """
+        return mask
+
+
+class Denormalize(DualTransform):
+    """
+    Denormalize an image or image patch.
+
+    Denormalization is performed expecting a zero mean and unit variance input. This
+    transform expects (Z)YXC dimensions.
+
+    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
+    division by zero during the normalization step, which is taken into account during
+    denormalization.
+
+    Attributes
+    ----------
+    mean : float
+        Mean value.
+    std : float
+        Standard deviation value.
+    eps : float
+        Epsilon value to avoid division by zero.
+    """
+
+    def __init__(
+        self,
+        mean: float,
+        std: float,
+    ):
+        super().__init__(always_apply=True, p=1)
+
+        self.mean = mean
+        self.std = std
+        self.eps = 1e-6
+
+    def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
+        """
+        Apply the transform to the image.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+        """
+        return patch * (self.std + self.eps) + self.mean
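A round-trip sketch (not from the diff): because Denormalize multiplies by std + eps, the same eps convention used by Normalize, normalizing and then denormalizing recovers the input up to float32 precision:

import numpy as np
from careamics.transforms import Denormalize, Normalize

patch = (100 * np.random.rand(8, 8, 1)).astype(np.float32)
mean, std = float(patch.mean()), float(patch.std())

normalized = Normalize(mean=mean, std=std).apply(patch)
restored = Denormalize(mean=mean, std=std).apply(normalized)
assert np.allclose(restored, patch, atol=1e-3)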