careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/unet.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UNet model.
|
|
3
|
+
|
|
4
|
+
A UNet encoder, decoder and complete model.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, List, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
from ..config.support import SupportedActivation
|
|
13
|
+
from .activation import get_activation
|
|
14
|
+
from .layers import Conv_Block, MaxBlurPool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnetEncoder(nn.Module):
    """
    Contracting (downsampling) pathway of the UNet.

    The encoder alternates convolution blocks and a pooling layer. The number
    of output channels doubles at every level.

    Parameters
    ----------
    conv_dim : int
        Dimensionality of the convolution layers, 2 for 2D or 3 for 3D.
    in_channels : int, optional
        Number of channels of the input tensor, by default 1.
    depth : int, optional
        Number of encoder levels (convolution block + pooling), by default 3.
    num_channels_init : int, optional
        Number of channels produced by the first encoder block, by default 64.
    use_batch_norm : bool, optional
        Whether the convolution blocks use batch normalization, by default True.
    dropout : float, optional
        Dropout probability passed to the convolution blocks, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the max pooling layers, by default 2.
    n2v2 : bool, optional
        Whether to use the N2V2 architecture (pooling is done with
        `MaxBlurPool`), by default False.
    groups : int, optional
        Number of blocked connections from input channels to output channels;
        it also multiplies the channel count of every block, by default 1.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        n2v2: bool = False,
        groups: int = 1,
    ) -> None:
        """
        Build the pooling layer and the stack of encoder blocks.

        Parameters
        ----------
        conv_dim : int
            Dimensionality of the convolution layers, 2 for 2D or 3 for 3D.
        in_channels : int, optional
            Number of channels of the input tensor, by default 1.
        depth : int, optional
            Number of encoder levels, by default 3.
        num_channels_init : int, optional
            Number of channels produced by the first encoder block, by default 64.
        use_batch_norm : bool, optional
            Whether the convolution blocks use batch normalization, by default
            True.
        dropout : float, optional
            Dropout probability passed to the convolution blocks, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the max pooling layers, by default 2.
        n2v2 : bool, optional
            Whether to use the N2V2 architecture, by default False.
        groups : int, optional
            Number of blocked connections from input channels to output
            channels, by default 1.
        """
        super().__init__()

        # N2V2 swaps the plain max pooling for a max-blur pooling layer.
        if n2v2:
            pooling = MaxBlurPool(
                dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel
            )
        else:
            pooling = getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
        self.pooling = pooling

        layers = []
        # The first block consumes the image channels, every later block
        # consumes the output of the block one level above it.
        block_in = in_channels
        for level in range(depth):
            block_out = num_channels_init * (2**level) * groups
            layers.append(
                Conv_Block(
                    conv_dim,
                    in_channels=block_in,
                    out_channels=block_out,
                    dropout_perc=dropout,
                    use_batch_norm=use_batch_norm,
                    groups=groups,
                )
            )
            # the very same pooling module is reused after every block
            layers.append(self.pooling)
            block_in = block_out

        self.encoder_blocks = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Run the input through all encoder levels.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        List[torch.Tensor]
            The final (pooled) encoder output first, followed by the output of
            each convolution block (the skip connections) in encoder order.
        """
        skips = []
        for layer in self.encoder_blocks:
            x = layer(x)
            # only convolution outputs are kept as skip connections,
            # pooled tensors are merely passed on to the next level
            if isinstance(layer, Conv_Block):
                skips.append(x)
        return [x, *skips]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class UnetDecoder(nn.Module):
    """
    Expanding (upsampling) pathway of the UNet.

    A bottleneck convolution block is followed by `depth` pairs of an
    upsampling layer and a convolution block. After each upsampling the tensor
    is merged with the matching encoder skip connection.

    Parameters
    ----------
    conv_dim : int
        Dimensionality of the convolution layers, 2 for 2D or 3 for 3D.
    depth : int, optional
        Number of decoder blocks, by default 3.
    num_channels_init : int, optional
        Number of channels in the first encoder block, by default 64.
    use_batch_norm : bool, optional
        Whether the convolution blocks use batch normalization, by default True.
    dropout : float, optional
        Dropout probability passed to the convolution blocks, by default 0.0.
    n2v2 : bool, optional
        Whether to use the N2V2 architecture, by default False.
    groups : int, optional
        Number of blocked connections from input channels to output channels,
        by default 1.
    """

    def __init__(
        self,
        conv_dim: int,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        n2v2: bool = False,
        groups: int = 1,
    ) -> None:
        """
        Build the bottleneck and the stack of decoder blocks.

        Parameters
        ----------
        conv_dim : int
            Dimensionality of the convolution layers, 2 for 2D or 3 for 3D.
        depth : int, optional
            Number of decoder blocks, by default 3.
        num_channels_init : int, optional
            Number of channels in the first encoder block, by default 64.
        use_batch_norm : bool, optional
            Whether the convolution blocks use batch normalization, by default
            True.
        dropout : float, optional
            Dropout probability passed to the convolution blocks, by default 0.0.
        n2v2 : bool, optional
            Whether to use the N2V2 architecture, by default False.
        groups : int, optional
            Number of blocked connections from input channels to output
            channels, by default 1.
        """
        super().__init__()

        self.n2v2 = n2v2
        self.groups = groups

        # A single parameter-free upsampling layer is shared by all levels.
        upsample_mode = "bilinear" if conv_dim == 2 else "trilinear"
        upsampling = nn.Upsample(scale_factor=2, mode=upsample_mode)

        # The bottleneck keeps the channel count of the deepest encoder block.
        bottleneck_channels = num_channels_init * groups * (2 ** (depth - 1))
        self.bottleneck = Conv_Block(
            conv_dim,
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            intermediate_channel_multiplier=2,
            use_batch_norm=use_batch_norm,
            dropout_perc=dropout,
            groups=self.groups,
        )

        layers: List[nn.Module] = []
        for level in range(depth):
            layers.append(upsampling)

            stage_channels = (num_channels_init * 2 ** (depth - level)) * groups
            # From the second level on, the convolution receives half as many
            # additional channels on top of `stage_channels`.
            if level > 0:
                conv_in = stage_channels + stage_channels // 2
            else:
                conv_in = stage_channels

            layers.append(
                Conv_Block(
                    conv_dim,
                    in_channels=conv_in,
                    out_channels=stage_channels // 2,
                    intermediate_channel_multiplier=2,
                    dropout_perc=dropout,
                    activation="ReLU",
                    use_batch_norm=use_batch_norm,
                    groups=groups,
                )
            )

        self.decoder_blocks = nn.ModuleList(layers)

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        """
        Decode the encoder features.

        Parameters
        ----------
        *features : List[torch.Tensor]
            The final output of the encoder first, followed by the output of
            each encoder block (the skip connections).

        Returns
        -------
        torch.Tensor
            Output of the decoder.
        """
        x: torch.Tensor = features[0]
        # everything but the first entry, deepest skip connection first
        skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]

        x = self.bottleneck(x)

        for index, layer in enumerate(self.decoder_blocks):
            x = layer(x)
            if not isinstance(layer, nn.Upsample):
                continue

            # layers alternate upsampling/convolution, so two layers make up
            # one decoder level
            skip_connection: torch.Tensor = skip_connections[index // 2]

            # Without N2V2 the skip connection is always merged in. With N2V2
            # it is merged only if `x` differs in shape from the last skip
            # connection.
            # NOTE(review): the comparison covers the whole shape, channel
            # axis included - confirm this matches the intended N2V2 rule.
            if not self.n2v2 or x.shape != skip_connections[-1].shape:
                x = self._interleave(x, skip_connection, self.groups)
        return x

    @staticmethod
    def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
        """Interleave the channel groups of two tensors.

        Both tensors are cut into `groups` equally sized chunks along the
        channel axis (axis=1). The chunks are then concatenated along that
        axis in alternating order: first chunk of `A`, first chunk of `B`,
        second chunk of `A`, second chunk of `B`, and so on.

        Parameters
        ----------
        A : torch.Tensor
            First tensor.
        B : torch.Tensor
            Second tensor.
        groups : int
            The number of groups.

        Returns
        -------
        torch.Tensor
            Interleaved tensor.

        Raises
        ------
        ValueError:
            If either of `A` or `B`'s channel axis is not divisible by `groups`.
        """
        channels_a = A.shape[1]
        channels_b = B.shape[1]
        if channels_a % groups or channels_b % groups:
            raise ValueError(f"Number of channels not divisible by {groups} groups.")

        step_a = channels_a // groups
        step_b = channels_b // groups

        pieces: List[torch.Tensor] = []
        for group in range(groups):
            pieces.append(A[:, group * step_a : (group + 1) * step_a])
            pieces.append(B[:, group * step_b : (group + 1) * step_b])

        return torch.cat(pieces, dim=1)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class UNet(nn.Module):
    """
    Complete UNet: encoder, decoder, 1x1 output convolution and activation.

    Adapted for PyTorch from:
    https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.

    Parameters
    ----------
    conv_dims : int
        Number of dimensions of the convolution layers (2 or 3).
    num_classes : int, optional
        Number of output channels of the final convolution, by default 1.
    in_channels : int, optional
        Number of input channels, by default 1.
    depth : int, optional
        Number of downsamplings, by default 3.
    num_channels_init : int, optional
        Number of filters in the first convolution layer, by default 64.
    use_batch_norm : bool, optional
        Whether to use batch normalization, by default True.
    dropout : float, optional
        Dropout probability, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the pooling layers, by default 2.
    final_activation : SupportedActivation or str, optional
        Activation applied after the last layer, resolved through
        `get_activation`, by default `SupportedActivation.NONE`.
    n2v2 : bool, optional
        Whether to use N2V2 architecture, by default False.
    independent_channels : bool
        Whether to train the channels independently (one convolution group per
        input channel), by default True.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        conv_dims: int,
        num_classes: int = 1,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
        n2v2: bool = False,
        independent_channels: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Assemble the encoder, the decoder and the output head.

        Parameters
        ----------
        conv_dims : int
            Number of dimensions of the convolution layers (2 or 3).
        num_classes : int, optional
            Number of output channels of the final convolution, by default 1.
        in_channels : int, optional
            Number of input channels, by default 1.
        depth : int, optional
            Number of downsamplings, by default 3.
        num_channels_init : int, optional
            Number of filters in the first convolution layer, by default 64.
        use_batch_norm : bool, optional
            Whether to use batch normalization, by default True.
        dropout : float, optional
            Dropout probability, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the pooling layers, by default 2.
        final_activation : SupportedActivation or str, optional
            Activation applied after the last layer, resolved through
            `get_activation`, by default `SupportedActivation.NONE`.
        n2v2 : bool, optional
            Whether to use N2V2 architecture, by default False.
        independent_channels : bool
            Whether to train parallel independent networks for each channel, by
            default True.
        **kwargs : Any
            Additional keyword arguments, unused.
        """
        super().__init__()

        # One convolution group per input channel keeps the channels apart;
        # a single group lets all channels mix.
        if independent_channels:
            groups = in_channels
        else:
            groups = 1

        self.encoder = UnetEncoder(
            conv_dims,
            in_channels=in_channels,
            depth=depth,
            num_channels_init=num_channels_init,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            pool_kernel=pool_kernel,
            n2v2=n2v2,
            groups=groups,
        )

        self.decoder = UnetDecoder(
            conv_dims,
            depth=depth,
            num_channels_init=num_channels_init,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            n2v2=n2v2,
            groups=groups,
        )

        # 1x1 convolution mapping the decoder output to `num_classes` channels
        conv_layer = getattr(nn, f"Conv{conv_dims}d")
        self.final_conv = conv_layer(
            in_channels=num_channels_init * groups,
            out_channels=num_classes,
            kernel_size=1,
            groups=groups,
        )
        self.final_activation = get_activation(final_activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the full network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output of the model.
        """
        features = self.encoder(x)
        decoded = self.decoder(*features)
        return self.final_activation(self.final_conv(decoded))
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Package to house various prediction utilies."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"stitch_prediction",
|
|
5
|
+
"stitch_prediction_single",
|
|
6
|
+
"convert_outputs",
|
|
7
|
+
]
|
|
8
|
+
|
|
9
|
+
from .prediction_outputs import convert_outputs
|
|
10
|
+
from .stitch_prediction import stitch_prediction, stitch_prediction_single
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Literal, Tuple, Union, overload
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from ..config.tile_information import TileInformation
|
|
9
|
+
from .stitch_prediction import stitch_prediction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def convert_outputs(predictions: List[Any], tiled: bool) -> List[NDArray]:
    """
    Convert the Lightning trainer outputs to the desired form.

    The batches returned by the trainer are combined. If the predictions are
    tiled, the tiles are additionally stitched back together.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray
        List of arrays with the axes SC(Z)YX. A list is always returned, even
        if it holds a single array. An empty `predictions` list is returned
        as is.
    """
    # nothing to combine or stitch
    if not predictions:
        return predictions

    # this layout is to stop mypy complaining
    if tiled:
        predictions_comb = combine_batches(predictions, tiled)
        predictions_output = stitch_prediction(*predictions_comb)
    else:
        predictions_output = combine_batches(predictions, tiled)

    return predictions_output
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# for mypy
|
|
45
|
+
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[True]
) -> Tuple[List[NDArray], List[TileInformation]]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[False]
) -> List[NDArray]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...


def combine_batches(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
    """
    Merge batched predictions, dispatching on whether they are tiled.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches. The tuple form is returned for tiled predictions.
    """
    if not tiled:
        return _combine_array_batches(predictions)
    return _combine_tiled_batches(predictions)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _combine_tiled_batches(
    predictions: List[Tuple[NDArray, List[TileInformation]]]
) -> Tuple[List[NDArray], List[TileInformation]]:
    """
    Merge the batches of a tiled prediction.

    Parameters
    ----------
    predictions : list of (numpy.ndarray, list of TileInformation)
        Predictions that are output from `Trainer.predict`. For tiled batches, this is
        a list of tuples. The first element of the tuples is the prediction output of
        tiles with dimension (B, C, (Z), Y, X), where B is batch size. The second
        element of the tuples is a list of TileInformation objects of length B.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        The individual tiles and the tile information objects, both flattened
        over all batches.
    """
    tile_arrays: List[NDArray] = []
    tile_infos: List[TileInformation] = []
    for batch_array, batch_infos in predictions:
        tile_arrays.append(batch_array)
        # flatten the per-batch lists of tile information into a single list
        tile_infos.extend(batch_infos)

    prediction_tiles: List[NDArray] = _combine_array_batches(tile_arrays)
    return prediction_tiles, tile_infos
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
    """
    Turn a list of batched arrays into a list of single-sample arrays.

    Parameters
    ----------
    predictions : list
        Prediction arrays that are output from `Trainer.predict`. A list of arrays that
        have dimensions (B, C, (Z), Y, X), where B is batch size.

    Returns
    -------
    list of numpy.ndarray
        A list of arrays with dimensions (1, C, (Z), Y, X).
    """
    # stack all batches along the first axis ...
    stacked: NDArray = np.concatenate(predictions, axis=0)
    # ... then cut the result into one array per sample, keeping the axis
    n_samples = stacked.shape[0]
    return np.split(stacked, n_samples, axis=0)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Prediction utility functions."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from typing import List, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.tile_information import TileInformation
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO: why not allow input and output of torch.tensor ?
|
|
13
|
+
def stitch_prediction(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> List[np.ndarray]:
    """
    Stitch tiles back together to form a full image(s).

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
    singleton dimension.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain tiles
        from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s).
    """
    # The `last_tile` flag marks the end of each image, so the flagged
    # positions are where the flat tile lists have to be cut.
    last_flags = [info.last_tile for info in tile_infos]
    last_positions = np.where(last_flags)[0]

    stitched_images = []
    start = 0
    # Tiles located after the final flagged tile belong to no slice and are
    # therefore not stitched.
    for last_position in last_positions:
        stop = last_position + 1
        stitched_images.append(
            stitch_prediction_single(tiles[start:stop], tile_infos[start:stop])
        )
        start = stop
    return stitched_images
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def stitch_prediction_single(
    tiles: List[NDArray],
    tile_infos: List[TileInformation],
) -> NDArray:
    """
    Stitch the tiles of one image back together.

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
    singleton dimension.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image, with dimensions SC(Z)YX.
    """
    # The full array shape is read from the first tile; a leading singleton
    # S dimension is added.
    full_shape = (1, *tile_infos[0].array_shape)
    stitched = np.zeros(full_shape, dtype=np.float32)

    for tile, info in zip(tiles, tile_infos):
        # Region of the predicted tile that is kept (the overlap is cropped);
        # the ellipsis leaves all leading axes untouched.
        crop_index: tuple[Union[builtins.ellipsis, slice], ...] = (
            ...,
            *[slice(coords[0], coords[1]) for coords in info.overlap_crop_coords],
        )
        kept_region = tile[crop_index]

        # Location of that region inside the full image.
        target_index = (
            ...,
            *[slice(coords[0], coords[1]) for coords in info.stitch_coords],
        )
        stitched[target_index] = kept_region.astype(np.float32)

    return stitched
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Transforms that are used to augment the data."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"get_all_transforms",
|
|
5
|
+
"N2VManipulate",
|
|
6
|
+
"XYFlip",
|
|
7
|
+
"XYRandomRotate90",
|
|
8
|
+
"ImageRestorationTTA",
|
|
9
|
+
"Denormalize",
|
|
10
|
+
"Normalize",
|
|
11
|
+
"Compose",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from .compose import Compose, get_all_transforms
|
|
16
|
+
from .n2v_manipulate import N2VManipulate
|
|
17
|
+
from .normalize import Denormalize, Normalize
|
|
18
|
+
from .tta import ImageRestorationTTA
|
|
19
|
+
from .xy_flip import XYFlip
|
|
20
|
+
from .xy_random_rotate90 import XYRandomRotate90
|