careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/unet.py
ADDED
@@ -0,0 +1,322 @@
+"""
+UNet model.
+
+A UNet encoder, decoder and complete model.
+"""
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn as nn
+
+from .layers import Conv_Block
+
+
+class UnetEncoder(nn.Module):
+    """
+    Unet encoder pathway.
+
+    Parameters
+    ----------
+    conv_dim : int
+        Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+    in_channels : int, optional
+        Number of input channels, by default 1.
+    depth : int, optional
+        Number of encoder blocks, by default 3.
+    num_channels_init : int, optional
+        Number of channels in the first encoder block, by default 64.
+    use_batch_norm : bool, optional
+        Whether to use batch normalization, by default True.
+    dropout : float, optional
+        Dropout probability, by default 0.0.
+    pool_kernel : int, optional
+        Kernel size for the max pooling layers, by default 2.
+    """
+
+    def __init__(
+        self,
+        conv_dim: int,
+        in_channels: int = 1,
+        depth: int = 3,
+        num_channels_init: int = 64,
+        use_batch_norm: bool = True,
+        dropout: float = 0.0,
+        pool_kernel: int = 2,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        conv_dim : int
+            Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+        in_channels : int, optional
+            Number of input channels, by default 1.
+        depth : int, optional
+            Number of encoder blocks, by default 3.
+        num_channels_init : int, optional
+            Number of channels in the first encoder block, by default 64.
+        use_batch_norm : bool, optional
+            Whether to use batch normalization, by default True.
+        dropout : float, optional
+            Dropout probability, by default 0.0.
+        pool_kernel : int, optional
+            Kernel size for the max pooling layers, by default 2.
+        """
+        super().__init__()
+
+        self.pooling = getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
+
+        encoder_blocks = []
+
+        for n in range(depth):
+            out_channels = num_channels_init * (2**n)
+            in_channels = in_channels if n == 0 else out_channels // 2
+            encoder_blocks.append(
+                Conv_Block(
+                    conv_dim,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    dropout_perc=dropout,
+                    use_batch_norm=use_batch_norm,
+                )
+            )
+            encoder_blocks.append(self.pooling)
+
+        self.encoder_blocks = nn.ModuleList(encoder_blocks)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        List[torch.Tensor]
+            Output of each encoder block (skip connections) and final output of the
+            encoder.
+        """
+        encoder_features = []
+        for module in self.encoder_blocks:
+            x = module(x)
+            if isinstance(module, Conv_Block):
+                encoder_features.append(x)
+        features = [x, *encoder_features]
+        return features
+
+
+class UnetDecoder(nn.Module):
+    """
+    Unet decoder pathway.
+
+    Parameters
+    ----------
+    conv_dim : int
+        Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+    depth : int, optional
+        Number of decoder blocks, by default 3.
+    num_channels_init : int, optional
+        Number of channels in the first encoder block, by default 64.
+    use_batch_norm : bool, optional
+        Whether to use batch normalization, by default True.
+    dropout : float, optional
+        Dropout probability, by default 0.0.
+    """
+
+    def __init__(
+        self,
+        conv_dim: int,
+        depth: int = 3,
+        num_channels_init: int = 64,
+        use_batch_norm: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        conv_dim : int
+            Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+        depth : int, optional
+            Number of decoder blocks, by default 3.
+        num_channels_init : int, optional
+            Number of channels in the first encoder block, by default 64.
+        use_batch_norm : bool, optional
+            Whether to use batch normalization, by default True.
+        dropout : float, optional
+            Dropout probability, by default 0.0.
+        """
+        super().__init__()
+
+        upsampling = nn.Upsample(
+            scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
+        )
+        in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
+        self.bottleneck = Conv_Block(
+            conv_dim,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            intermediate_channel_multiplier=2,
+            use_batch_norm=use_batch_norm,
+            dropout_perc=dropout,
+        )
+
+        decoder_blocks = []
+        for n in range(depth):
+            decoder_blocks.append(upsampling)
+            in_channels = num_channels_init * 2 ** (depth - n)
+            out_channels = num_channels_init
+            decoder_blocks.append(
+                Conv_Block(
+                    conv_dim,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    intermediate_channel_multiplier=2,
+                    dropout_perc=dropout,
+                    activation="ReLU",
+                    use_batch_norm=use_batch_norm,
+                )
+            )
+
+        self.decoder_blocks = nn.ModuleList(decoder_blocks)
+
+    def forward(self, *features: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        *features : List[torch.Tensor]
+            List containing the output of each encoder block (skip connections) and
+            the final output of the encoder.
+
+        Returns
+        -------
+        torch.Tensor
+            Output of the decoder.
+        """
+        x = features[0]
+        skip_connections = features[1:][::-1]
+        x = self.bottleneck(x)
+        for i, module in enumerate(self.decoder_blocks):
+            x = module(x)
+            if isinstance(module, nn.Upsample):
+                x = torch.cat([x, skip_connections[i // 2]], axis=1)
+        return x
+
+
+class UNet(nn.Module):
+    """
+    UNet model.
+
+    Adapted for PyTorch from
+    https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
+
+    Parameters
+    ----------
+    conv_dim : int
+        Number of dimensions of the convolution layers (2 or 3).
+    num_classes : int, optional
+        Number of classes to predict, by default 1.
+    in_channels : int, optional
+        Number of input channels, by default 1.
+    depth : int, optional
+        Number of downsamplings, by default 3.
+    num_channels_init : int, optional
+        Number of filters in the first convolution layer, by default 64.
+    use_batch_norm : bool, optional
+        Whether to use batch normalization, by default True.
+    dropout : float, optional
+        Dropout probability, by default 0.0.
+    pool_kernel : int, optional
+        Kernel size of the pooling layers, by default 2.
+    last_activation : Optional[Callable], optional
+        Activation function to use for the last layer, by default None.
+    """
+
+    def __init__(
+        self,
+        conv_dim: int,
+        num_classes: int = 1,
+        in_channels: int = 1,
+        depth: int = 3,
+        num_channels_init: int = 64,
+        use_batch_norm: bool = True,
+        dropout: float = 0.0,
+        pool_kernel: int = 2,
+        last_activation: Optional[Callable] = None,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        conv_dim : int
+            Number of dimensions of the convolution layers (2 or 3).
+        num_classes : int, optional
+            Number of classes to predict, by default 1.
+        in_channels : int, optional
+            Number of input channels, by default 1.
+        depth : int, optional
+            Number of downsamplings, by default 3.
+        num_channels_init : int, optional
+            Number of filters in the first convolution layer, by default 64.
+        use_batch_norm : bool, optional
+            Whether to use batch normalization, by default True.
+        dropout : float, optional
+            Dropout probability, by default 0.0.
+        pool_kernel : int, optional
+            Kernel size of the pooling layers, by default 2.
+        last_activation : Optional[Callable], optional
+            Activation function to use for the last layer, by default None.
+        """
+        super().__init__()
+
+        self.encoder = UnetEncoder(
+            conv_dim,
+            in_channels=in_channels,
+            depth=depth,
+            num_channels_init=num_channels_init,
+            use_batch_norm=use_batch_norm,
+            dropout=dropout,
+            pool_kernel=pool_kernel,
+        )
+
+        self.decoder = UnetDecoder(
+            conv_dim,
+            depth=depth,
+            num_channels_init=num_channels_init,
+            use_batch_norm=use_batch_norm,
+            dropout=dropout,
+        )
+        self.final_conv = getattr(nn, f"Conv{conv_dim}d")(
+            in_channels=num_channels_init,
+            out_channels=num_classes,
+            kernel_size=1,
+        )
+        self.last_activation = last_activation if last_activation else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output of the model.
+        """
+        encoder_features = self.encoder(x)
+        x = self.decoder(*encoder_features)
+        x = self.final_conv(x)
+        x = self.last_activation(x)
+        return x
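The encoder returns the final feature map followed by the per-block skip connections, which the decoder consumes via `forward(*features)`. For orientation, a minimal sketch of how the model could be exercised on dummy data; the shapes and `depth=2` are arbitrary illustration choices (not values from the package), and it assumes `Conv_Block` from careamics/models/layers.py maps its `in_channels` to `out_channels` as used above:

import torch
from careamics.models.unet import UNet

# Assumed: a 2D single-channel input of shape (S, C, Y, X) whose spatial size
# is divisible by 2**depth, so pooling and upsampling round-trip cleanly.
model = UNet(conv_dim=2, depth=2, num_channels_init=32)
model.eval()

x = torch.rand(1, 1, 64, 64)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 64, 64]), spatial size preserved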
careamics/prediction/prediction_utils.py
ADDED
@@ -0,0 +1,106 @@
+"""
+Prediction convenience functions.
+
+These functions are used during prediction.
+"""
+from typing import List
+
+import numpy as np
+import torch
+
+
+def stitch_prediction(
+    tiles: List[np.ndarray],
+    stitching_data: List,
+) -> np.ndarray:
+    """
+    Stitch tiles back together to form a full image.
+
+    Parameters
+    ----------
+    tiles : List[np.ndarray]
+        Cropped tiles.
+    stitching_data : List
+        List of coordinates obtained from
+        dataset.tiling.compute_crop_and_stitch_coords_1d.
+
+    Returns
+    -------
+    np.ndarray
+        Full image.
+    """
+    # Get whole sample shape
+    input_shape = stitching_data[0][0]
+    predicted_image = np.zeros(input_shape, dtype=np.float32)
+    for tile, (_, overlap_crop_coords, stitch_coords) in zip(tiles, stitching_data):
+        # Compute coordinates for cropping predicted tile
+        slices = tuple([slice(c[0], c[1]) for c in overlap_crop_coords])
+
+        # Crop predicted tile according to overlap coordinates
+        cropped_tile = tile.squeeze()[slices]
+
+        # Insert cropped tile into predicted image using stitch coordinates
+        predicted_image[
+            (..., *[slice(c[0], c[1]) for c in stitch_coords])
+        ] = cropped_tile
+    return predicted_image
+
+
+def tta_forward(x: torch.Tensor) -> List[torch.Tensor]:
+    """
+    Augment an array 8-fold.
+
+    The augmentation is performed using all 90 deg rotations and their flipped
+    versions, as well as the original image flipped.
+
+    Tensors should be of shape SC(Z)YX, with S and C potentially singleton dimensions.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Data to augment.
+
+    Returns
+    -------
+    List[torch.Tensor]
+        Stack of augmented images.
+    """
+    x_aug = [
+        x,
+        torch.rot90(x, 1, dims=(2, 3)),
+        torch.rot90(x, 2, dims=(2, 3)),
+        torch.rot90(x, 3, dims=(2, 3)),
+    ]
+    x_aug_flip = x_aug.copy()
+    for x_ in x_aug:
+        x_aug_flip.append(torch.flip(x_, dims=(1, 3)))
+    return x_aug_flip
+
+
+def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray:
+    """
+    Invert `tta_forward` and average the 8 images.
+
+    The function takes a list of torch tensors and returns a numpy array.
+
+    Parameters
+    ----------
+    x_aug : List[torch.Tensor]
+        Stack of 8-fold augmented images.
+
+    Returns
+    -------
+    np.ndarray
+        Average of de-augmented x_aug.
+    """
+    x_deaug = [
+        x_aug[0].numpy(),
+        np.rot90(x_aug[1], -1, axes=(2, 3)),
+        np.rot90(x_aug[2], -2, axes=(2, 3)),
+        np.rot90(x_aug[3], -3, axes=(2, 3)),
+        np.flip(x_aug[4].numpy(), axis=(1, 3)),
+        np.rot90(np.flip(x_aug[5].numpy(), axis=(1, 3)), -1, axes=(2, 3)),
+        np.rot90(np.flip(x_aug[6].numpy(), axis=(1, 3)), -2, axes=(2, 3)),
+        np.rot90(np.flip(x_aug[7].numpy(), axis=(1, 3)), -3, axes=(2, 3)),
+    ]
+    return np.mean(x_deaug, 0)
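As a sanity check, the two TTA helpers invert each other: feeding the raw output of `tta_forward` straight into `tta_backward` (i.e. with an identity model in between) should recover the input, since each of the eight de-augmented arrays equals the original. A small sketch, assuming a CPU tensor of shape (S, C, Y, X) with square Y/X so the 90 degree rotations preserve the shape:

import numpy as np
import torch
from careamics.prediction.prediction_utils import tta_backward, tta_forward

x = torch.rand(1, 1, 32, 32)
augmented = tta_forward(x)          # list of 8 augmented tensors
restored = tta_backward(augmented)  # averaged de-augmented result, numpy array
assert np.allclose(restored, x.numpy())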
careamics/utils/__init__.py
ADDED
@@ -0,0 +1,20 @@
+"""Utils module."""
+
+
+__all__ = [
+    "denormalize",
+    "normalize",
+    "get_device",
+    "check_axes_validity",
+    "add_axes",
+    "check_tiling_validity",
+    "cwd",
+    "MetricTracker",
+]
+
+
+from .context import cwd
+from .metrics import MetricTracker
+from .normalization import denormalize, normalize
+from .torch_utils import get_device
+from .validators import add_axes, check_axes_validity, check_tiling_validity
careamics/utils/ascii_logo.txt
ADDED
@@ -0,0 +1,9 @@
+...... ...... ........ ........ ....
+-+++----+- -+++--+++- :+++---+++: :+++----- .--:
+.+++ .: +++. .+++. :+++ :+++ :+++ :------. .---:----..:----. :--- :----: :----:.
+.+++ .+++. .+++. :+++ -++= :+++ +=....=+++ :+++-..=+++-..=++= -+++ .+++-..++ +++-..=+.
+.+++ .++++++++++. :++++++++=. :++++++: .+++. :+++ :+++ -+++ -+++ :+++ .+++=.
+.+++ .+++. .+++. :+++ -+++ :+++ :=++==++++. :+++ :+++ -+++ -+++ :+++ .-=+++=:
+.+++ .. .+++. .+++. :+++ :+++ :+++ .+++. .+++. :+++ :+++ -+++ -+++ :+++ .. .. :+++.
+-++=-::-+= .+++. .+++. :+++ :+++ :+++-:::: =++=--=+++. :+++ :+++ -+++ -+++ =++=:-+= =+-:=++=
+...... ... ... ... ... ........ .... ... ... ... .... .... .... .....
careamics/utils/augment.py
ADDED
@@ -0,0 +1,65 @@
+"""Augmentation module."""
+from typing import Tuple
+
+import numpy as np
+
+
+# TODO: unused?
+def _flip_and_rotate(
+    image: np.ndarray, rotate_state: int, flip_state: int
+) -> np.ndarray:
+    """
+    Apply the given number of 90 degrees rotations and flip to an array.
+
+    Parameters
+    ----------
+    image : np.ndarray
+        Array containing single image or patch, 2D or 3D.
+    rotate_state : int
+        Number of 90 degree rotations to apply.
+    flip_state : int
+        0 or 1, whether to flip the array or not.
+
+    Returns
+    -------
+    np.ndarray
+        Flipped and rotated array.
+    """
+    rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))
+    flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated
+    return flipped.copy()
+
+
+def augment_batch(
+    patch: np.ndarray,
+    original_image: np.ndarray,
+    mask: np.ndarray,
+    seed: int = 42,
+) -> Tuple[np.ndarray, ...]:
+    """
+    Apply augmentation function to patches and masks.
+
+    Parameters
+    ----------
+    patch : np.ndarray
+        Array containing single image or patch, 2D or 3D with masked pixels.
+    original_image : np.ndarray
+        Array containing original image or patch, 2D or 3D.
+    mask : np.ndarray
+        Array containing only masked pixels, 2D or 3D.
+    seed : int, optional
+        Seed for the random number generator; controls the rotation and flipping.
+
+    Returns
+    -------
+    Tuple[np.ndarray, ...]
+        Tuple of augmented arrays.
+    """
+    rng = np.random.default_rng(seed=seed)
+    rotate_state = rng.integers(0, 4)
+    flip_state = rng.integers(0, 2)
+    return (
+        _flip_and_rotate(patch, rotate_state, flip_state),
+        _flip_and_rotate(original_image, rotate_state, flip_state),
+        _flip_and_rotate(mask, rotate_state, flip_state),
+    )
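A minimal illustration with arbitrary dummy arrays. Note that the generator is rebuilt from `seed` on every call, so calls with the default `seed=42` always pick the same rotation/flip combination; callers wanting varied augmentations must pass a different seed each time:

import numpy as np
from careamics.utils.augment import augment_batch

patch = np.random.rand(8, 8).astype(np.float32)
original = np.random.rand(8, 8).astype(np.float32)
mask = np.random.rand(8, 8).astype(np.float32)

# The same rotation and flip are applied to all three arrays, keeping the
# masked pixel positions aligned with the image content.
aug_patch, aug_original, aug_mask = augment_batch(patch, original, mask, seed=0)
assert aug_patch.shape == aug_original.shape == aug_mask.shape == (8, 8)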
careamics/utils/context.py
ADDED
@@ -0,0 +1,45 @@
+"""
+Context submodule.
+
+A convenience function to change the working directory in order to save data.
+"""
+import os
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterator, Union
+
+
+@contextmanager
+def cwd(path: Union[str, Path]) -> Iterator[None]:
+    """
+    Change the current working directory to the given path.
+
+    This context manager can be used to generate files in a specific directory; once
+    out of the context, the working directory is set back to the original one.
+
+    Parameters
+    ----------
+    path : Union[str, Path]
+        New working directory path.
+
+    Returns
+    -------
+    Iterator[None]
+        None values.
+
+    Examples
+    --------
+    >>> with cwd(path):
+    ...     pass
+    """
+    path = Path(path)
+
+    if not path.exists():
+        path.mkdir(parents=True, exist_ok=True)
+
+    old_pwd = Path(".").absolute()
+    os.chdir(path)
+    try:
+        yield
+    finally:
+        os.chdir(old_pwd)