careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (48) hide show
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,322 @@
1
+ """
2
+ UNet model.
3
+
4
+ A UNet encoder, decoder and complete model.
5
+ """
6
+ from typing import Callable, List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .layers import Conv_Block
12
+
13
+
14
class UnetEncoder(nn.Module):
    """
    Contracting (down-sampling) pathway of the UNet.

    The encoder is a flat sequence of `depth` stages, each stage being a
    `Conv_Block` followed by a max pooling layer. Stage `n` produces
    `num_channels_init * 2**n` channels.

    Parameters
    ----------
    conv_dim : int
        Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
    in_channels : int, optional
        Number of channels of the input tensor, by default 1.
    depth : int, optional
        Number of encoder stages, by default 3.
    num_channels_init : int, optional
        Number of output channels of the first stage, by default 64.
    use_batch_norm : bool, optional
        Forwarded to every `Conv_Block`, by default True.
    dropout : float, optional
        Dropout probability forwarded to every `Conv_Block`, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the max pooling layer, by default 2.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
    ) -> None:
        """
        Build the pooling layer and the list of encoder stages.

        Parameters
        ----------
        conv_dim : int
            Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
        in_channels : int, optional
            Number of channels of the input tensor, by default 1.
        depth : int, optional
            Number of encoder stages, by default 3.
        num_channels_init : int, optional
            Number of output channels of the first stage, by default 64.
        use_batch_norm : bool, optional
            Forwarded to every `Conv_Block`, by default True.
        dropout : float, optional
            Dropout probability forwarded to every `Conv_Block`, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the max pooling layer, by default 2.
        """
        super().__init__()

        # Resolve nn.MaxPool2d / nn.MaxPool3d from the requested dimensionality.
        pool_cls = getattr(nn, f"MaxPool{conv_dim}d")
        self.pooling = pool_cls(kernel_size=pool_kernel)

        stages = []
        stage_in = in_channels
        for level in range(depth):
            stage_out = num_channels_init * (2**level)
            stages.append(
                Conv_Block(
                    conv_dim,
                    in_channels=stage_in,
                    out_channels=stage_out,
                    dropout_perc=dropout,
                    use_batch_norm=use_batch_norm,
                )
            )
            # The single pooling module is reused after every convolution block.
            stages.append(self.pooling)
            # The next stage consumes what this one produced.
            stage_in = stage_out

        self.encoder_blocks = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Run the input through every encoder stage.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        List[torch.Tensor]
            The final encoder output first, followed by the output of each
            `Conv_Block` (the skip connections) in encoding order.
        """
        skips = []
        for layer in self.encoder_blocks:
            x = layer(x)
            # Only convolution outputs are kept as skips, not the pooled tensors.
            if isinstance(layer, Conv_Block):
                skips.append(x)
        return [x] + skips
110
+
111
+
112
class UnetDecoder(nn.Module):
    """
    Expanding (up-sampling) pathway of the UNet.

    A bottleneck `Conv_Block` is followed by `depth` stages, each stage being
    an up-sampling layer followed by a `Conv_Block`. After every up-sampling
    the matching skip connection is concatenated along axis 1.

    Parameters
    ----------
    conv_dim : int
        Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
    depth : int, optional
        Number of decoder stages, by default 3.
    num_channels_init : int, optional
        Number of channels in the first encoder block, by default 64.
    use_batch_norm : bool, optional
        Forwarded to every `Conv_Block`, by default True.
    dropout : float, optional
        Dropout probability forwarded to every `Conv_Block`, by default 0.0.
    """

    def __init__(
        self,
        conv_dim: int,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
    ) -> None:
        """
        Build the bottleneck and the list of decoder stages.

        Parameters
        ----------
        conv_dim : int
            Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
        depth : int, optional
            Number of decoder stages, by default 3.
        num_channels_init : int, optional
            Number of channels in the first encoder block, by default 64.
        use_batch_norm : bool, optional
            Forwarded to every `Conv_Block`, by default True.
        dropout : float, optional
            Dropout probability forwarded to every `Conv_Block`, by default 0.0.
        """
        super().__init__()

        # Bilinear interpolation for 2D, trilinear for anything else.
        interpolation = "bilinear" if conv_dim == 2 else "trilinear"
        upsampling = nn.Upsample(scale_factor=2, mode=interpolation)

        # The bottleneck keeps the channel count of the deepest encoder block.
        bottleneck_channels = num_channels_init * 2 ** (depth - 1)
        self.bottleneck = Conv_Block(
            conv_dim,
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            intermediate_channel_multiplier=2,
            use_batch_norm=use_batch_norm,
            dropout_perc=dropout,
        )

        # NOTE(review): every stage is created with out_channels set to
        # num_channels_init while in_channels follows
        # num_channels_init * 2 ** (depth - level). Whether the concatenated
        # tensor matches in_channels for depth > 2 depends on Conv_Block, which
        # is not visible here - TODO confirm the channel arithmetic.
        stages = []
        for level in range(depth):
            # The same up-sampling module is reused in front of every block.
            stages.append(upsampling)
            stages.append(
                Conv_Block(
                    conv_dim,
                    in_channels=num_channels_init * 2 ** (depth - level),
                    out_channels=num_channels_init,
                    intermediate_channel_multiplier=2,
                    dropout_perc=dropout,
                    activation="ReLU",
                    use_batch_norm=use_batch_norm,
                )
            )

        self.decoder_blocks = nn.ModuleList(stages)

    def forward(self, *features: List[torch.Tensor]) -> torch.Tensor:
        """
        Decode the encoder features into a single tensor.

        Parameters
        ----------
        *features : List[torch.Tensor]
            Encoder outputs passed as separate positional arguments: the final
            encoder output first, then the skip connections in encoding order.

        Returns
        -------
        torch.Tensor
            Output of the last decoder block.
        """
        x = features[0]
        # Skip connections are consumed deepest first.
        skips = tuple(reversed(features[1:]))
        x = self.bottleneck(x)

        stage = 0
        for layer in self.decoder_blocks:
            x = layer(x)
            if isinstance(layer, nn.Upsample):
                # Merge with the skip connection of the matching stage along
                # axis 1 before the next convolution block.
                x = torch.cat([x, skips[stage]], axis=1)
                stage += 1
        return x
211
+
212
+
213
class UNet(nn.Module):
    """
    Complete UNet: encoder, decoder, 1x1 output convolution and activation.

    Adapted for PyTorch from
    https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.

    Parameters
    ----------
    conv_dim : int
        Number of dimensions of the convolution layers (2 or 3).
    num_classes : int, optional
        Number of output channels of the final convolution, by default 1.
    in_channels : int, optional
        Number of input channels, by default 1.
    depth : int, optional
        Number of encoder and decoder stages, by default 3.
    num_channels_init : int, optional
        Number of filters in the first convolution block, by default 64.
    use_batch_norm : bool, optional
        Whether to use batch normalization, by default True.
    dropout : float, optional
        Dropout probability, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the pooling layers, by default 2.
    last_activation : Optional[Callable], optional
        Callable applied to the output of the final convolution, by default
        None, in which case the output is returned unchanged.
    """

    def __init__(
        self,
        conv_dim: int,
        num_classes: int = 1,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        last_activation: Optional[Callable] = None,
    ) -> None:
        """
        Assemble the encoder, decoder, final convolution and last activation.

        Parameters
        ----------
        conv_dim : int
            Number of dimensions of the convolution layers (2 or 3).
        num_classes : int, optional
            Number of output channels of the final convolution, by default 1.
        in_channels : int, optional
            Number of input channels, by default 1.
        depth : int, optional
            Number of encoder and decoder stages, by default 3.
        num_channels_init : int, optional
            Number of filters in the first convolution block, by default 64.
        use_batch_norm : bool, optional
            Whether to use batch normalization, by default True.
        dropout : float, optional
            Dropout probability, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the pooling layers, by default 2.
        last_activation : Optional[Callable], optional
            Callable applied to the output of the final convolution, by default
            None, in which case the output is returned unchanged.
        """
        super().__init__()

        # Settings shared by both pathways.
        shared = {
            "depth": depth,
            "num_channels_init": num_channels_init,
            "use_batch_norm": use_batch_norm,
            "dropout": dropout,
        }

        self.encoder = UnetEncoder(
            conv_dim,
            in_channels=in_channels,
            pool_kernel=pool_kernel,
            **shared,
        )
        self.decoder = UnetDecoder(conv_dim, **shared)

        # 1x1 convolution mapping the decoder channels to the output channels.
        conv_cls = getattr(nn, f"Conv{conv_dim}d")
        self.final_conv = conv_cls(
            in_channels=num_channels_init,
            out_channels=num_classes,
            kernel_size=1,
        )

        # Any falsy value (e.g. None) falls back to a pass-through module.
        if not last_activation:
            last_activation = nn.Identity()
        self.last_activation = last_activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the full model on an input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output of the last activation.
        """
        # The encoder returns a list; the decoder takes its items as
        # separate positional arguments.
        decoded = self.decoder(*self.encoder(x))
        return self.last_activation(self.final_conv(decoded))
@@ -0,0 +1,9 @@
1
+ """Prediction functions."""
2
+
3
+ __all__ = [
4
+ "stitch_prediction",
5
+ "tta_backward",
6
+ "tta_forward",
7
+ ]
8
+
9
+ from .prediction_utils import stitch_prediction, tta_backward, tta_forward
@@ -0,0 +1,106 @@
1
+ """
2
+ Prediction convenience functions.
3
+
4
+ These functions are used during prediction.
5
+ """
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
def stitch_prediction(
    tiles: List[np.ndarray],
    stitching_data: List,
) -> np.ndarray:
    """
    Assemble predicted tiles into a single full-size image.

    Parameters
    ----------
    tiles : List[np.ndarray]
        Predicted tiles, in the same order as `stitching_data`.
    stitching_data : List
        One entry per tile, each unpacked as three items: the shape of the
        full image, the crop coordinates to apply to the tile, and the
        coordinates at which the cropped tile is written in the full image.
        Coordinates are sequences of (start, stop) pairs, one per axis.

    Returns
    -------
    np.ndarray
        Full image, as a float32 array.
    """
    # The full image shape is read from the first entry only.
    full_shape = stitching_data[0][0]
    stitched = np.zeros(full_shape, dtype=np.float32)

    for tile, (_, crop_bounds, target_bounds) in zip(tiles, stitching_data):
        # Region of the predicted tile to keep (overlap removed).
        crop = tuple(slice(bound[0], bound[1]) for bound in crop_bounds)

        # Destination in the full image; the ellipsis covers any leading axes.
        target = (
            Ellipsis,
            *(slice(bound[0], bound[1]) for bound in target_bounds),
        )

        # Singleton axes are dropped from the tile before cropping.
        stitched[target] = tile.squeeze()[crop]

    return stitched
47
+
48
+
49
def tta_forward(x: torch.Tensor) -> List[torch.Tensor]:
    """
    Build the 8 test-time augmented versions of a tensor.

    The result holds the input, its three successive 90 degree rotations over
    dims (2, 3), and then each of those four tensors flipped over dims (1, 3),
    in that order.

    Tensors should be of shape SC(Z)YX, with S and C potentially singleton
    dimensions.

    Parameters
    ----------
    x : torch.Tensor
        Data to augment.

    Returns
    -------
    List
        The 8 augmented tensors; the first item is `x` itself.
    """
    # NOTE(review): with the SC(Z)YX layout stated above, dim 1 is the channel
    # axis, so the flip also reverses channel order, and for tensors with a Z
    # axis dims (2, 3) are not the YX plane - TODO confirm this is intended.
    rotations = [x] + [torch.rot90(x, turns, dims=(2, 3)) for turns in (1, 2, 3)]
    mirrored = [torch.flip(rotated, dims=(1, 3)) for rotated in rotations]
    return rotations + mirrored
78
+
79
+
80
def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray:
    """
    Invert `tta_forward` and average the 8 images.

    The function takes a list of torch tensors and returns a numpy array.
    Entries 0 to 3 are expected to be the 0, 90, 180 and 270 degree rotations
    over axes (2, 3); entries 4 to 7 the same rotations additionally flipped
    over axes (1, 3).

    Every entry is converted to numpy the same way (detached from the autograd
    graph and moved to host memory), so tensors that require grad or live on
    another device are accepted. Entries that are not tensors are passed
    through `np.asarray`.

    Parameters
    ----------
    x_aug : List[torch.Tensor]
        Stack of 8-fold augmented images.

    Returns
    -------
    np.ndarray
        Average of de-augmented x_aug.

    Raises
    ------
    IndexError
        If `x_aug` holds fewer than 8 entries.
    """

    def _as_numpy(item) -> np.ndarray:
        # Single conversion path for all 8 entries.
        if isinstance(item, torch.Tensor):
            return item.detach().cpu().numpy()
        return np.asarray(item)

    restored = []

    # Entries 0-3: undo the rotation (a rotation by 0 leaves values untouched).
    for turns in range(4):
        restored.append(np.rot90(_as_numpy(x_aug[turns]), -turns, axes=(2, 3)))

    # Entries 4-7: the flip was applied last, so undo it first, then rotate back.
    for turns in range(4):
        unflipped = np.flip(_as_numpy(x_aug[4 + turns]), axis=(1, 3))
        restored.append(np.rot90(unflipped, -turns, axes=(2, 3)))

    return np.mean(restored, 0)
@@ -0,0 +1,20 @@
1
+ """Utils module."""
2
+
3
+
4
+ __all__ = [
5
+ "denormalize",
6
+ "normalize",
7
+ "get_device",
8
+ "check_axes_validity",
9
+ "add_axes",
10
+ "check_tiling_validity",
11
+ "cwd",
12
+ "MetricTracker",
13
+ ]
14
+
15
+
16
+ from .context import cwd
17
+ from .metrics import MetricTracker
18
+ from .normalization import denormalize, normalize
19
+ from .torch_utils import get_device
20
+ from .validators import add_axes, check_axes_validity, check_tiling_validity
@@ -0,0 +1,9 @@
1
+ ...... ...... ........ ........ ....
2
+ -+++----+- -+++--+++- :+++---+++: :+++----- .--:
3
+ .+++ .: +++. .+++. :+++ :+++ :+++ :------. .---:----..:----. :--- :----: :----:.
4
+ .+++ .+++. .+++. :+++ -++= :+++ +=....=+++ :+++-..=+++-..=++= -+++ .+++-..++ +++-..=+.
5
+ .+++ .++++++++++. :++++++++=. :++++++: .+++. :+++ :+++ -+++ -+++ :+++ .+++=.
6
+ .+++ .+++. .+++. :+++ -+++ :+++ :=++==++++. :+++ :+++ -+++ -+++ :+++ .-=+++=:
7
+ .+++ .. .+++. .+++. :+++ :+++ :+++ .+++. .+++. :+++ :+++ -+++ -+++ :+++ .. .. :+++.
8
+ -++=-::-+= .+++. .+++. :+++ :+++ :+++-:::: =++=--=+++. :+++ :+++ -+++ -+++ =++=:-+= =+-:=++=
9
+ ...... ... ... ... ... ........ .... ... ... ... .... .... .... .....
@@ -0,0 +1,65 @@
1
+ """Augmentation module."""
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+
6
+
7
+ # TODO: unused?
8
+ def _flip_and_rotate(
9
+ image: np.ndarray, rotate_state: int, flip_state: int
10
+ ) -> np.ndarray:
11
+ """
12
+ Apply the given number of 90 degrees rotations and flip to an array.
13
+
14
+ Parameters
15
+ ----------
16
+ image : np.ndarray
17
+ Array containing single image or patch, 2D or 3D.
18
+ rotate_state : int
19
+ Number of 90 degree rotations to apply.
20
+ flip_state : int
21
+ 0 or 1, whether to flip the array or not.
22
+
23
+ Returns
24
+ -------
25
+ np.ndarray
26
+ Flipped and rotated array.
27
+ """
28
+ rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))
29
+ flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated
30
+ return flipped.copy()
31
+
32
+
33
def augment_batch(
    patch: np.ndarray,
    original_image: np.ndarray,
    mask: np.ndarray,
    seed: int = 42,
) -> Tuple[np.ndarray, ...]:
    """
    Apply one shared random rotation and flip to a patch, its original and mask.

    A random generator is created from `seed` on every call, and a single
    rotation count (0 to 3) and flip flag (0 or 1) are drawn from it. The same
    transformation is then applied to the three arrays. Calls with the same
    seed, including the default, therefore always apply the same
    transformation.

    Parameters
    ----------
    patch : np.ndarray
        Array containing single image or patch, 2D or 3D with masked pixels.
    original_image : np.ndarray
        Array containing original image or patch, 2D or 3D.
    mask : np.ndarray
        Array containing only masked pixels, 2D or 3D.
    seed : int, optional
        Seed for random number generator, controls the rotation and flipping.

    Returns
    -------
    Tuple[np.ndarray, ...]
        Augmented patch, original image and mask, in that order.
    """
    generator = np.random.default_rng(seed=seed)
    # Draw order matters for reproducibility: rotation first, then flip.
    rotate_state = generator.integers(0, 4)
    flip_state = generator.integers(0, 2)
    return tuple(
        _flip_and_rotate(array, rotate_state, flip_state)
        for array in (patch, original_image, mask)
    )
@@ -0,0 +1,45 @@
1
+ """
2
+ Context submodule.
3
+
4
+ A convenience function to change the working directory in order to save data.
5
+ """
6
+ import os
7
+ from contextlib import contextmanager
8
+ from pathlib import Path
9
+ from typing import Iterator, Union
10
+
11
+
12
+ @contextmanager
13
+ def cwd(path: Union[str, Path]) -> Iterator[None]:
14
+ """
15
+ Change the current working directory to the given path.
16
+
17
+ This method can be used to generate files in a specific directory, once out of the
18
+ context, the working directory is set back to the original one.
19
+
20
+ Parameters
21
+ ----------
22
+ path : Union[str,Path]
23
+ New working directory path.
24
+
25
+ Returns
26
+ -------
27
+ Iterator[None]
28
+ None values.
29
+
30
+ Examples
31
+ --------
32
+ >>> with cwd(path):
33
+ ... pass
34
+ """
35
+ path = Path(path)
36
+
37
+ if not path.exists():
38
+ path.mkdir(parents=True, exist_ok=True)
39
+
40
+ old_pwd = Path(".").absolute()
41
+ os.chdir(path)
42
+ try:
43
+ yield
44
+ finally:
45
+ os.chdir(old_pwd)