careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/unet.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UNet model.
|
|
3
|
+
|
|
4
|
+
A UNet encoder, decoder and complete model.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, List, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
from ..config.support import SupportedActivation
|
|
13
|
+
from .activation import get_activation
|
|
14
|
+
from .layers import Conv_Block, MaxBlurPool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnetEncoder(nn.Module):
    """
    Contracting (encoder) half of a UNet.

    Every level applies a `Conv_Block` followed by a pooling layer. The first
    block outputs `num_channels_init * groups` channels and every following
    block doubles that number.

    Parameters
    ----------
    conv_dim : int
        Number of spatial dimensions of the convolutions, 2 for 2D or 3 for 3D.
    in_channels : int, optional
        Number of input channels, by default 1.
    depth : int, optional
        Number of encoder levels (convolution block + pooling), by default 3.
    num_channels_init : int, optional
        Number of channels (per group) produced by the first encoder block, by
        default 64.
    use_batch_norm : bool, optional
        Whether to use batch normalization, by default True.
    dropout : float, optional
        Dropout probability, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the max pooling layers, by default 2.
    n2v2 : bool, optional
        Whether to use the N2V2 variant, which pools with `MaxBlurPool` instead
        of plain max pooling, by default False.
    groups : int, optional
        Number of blocked connections from input channels to output channels,
        by default 1.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        n2v2: bool = False,
        groups: int = 1,
    ) -> None:
        """
        Build the pooling layer and the encoder levels.

        Parameters
        ----------
        conv_dim : int
            Number of spatial dimensions of the convolutions, 2 for 2D or 3 for 3D.
        in_channels : int, optional
            Number of input channels, by default 1.
        depth : int, optional
            Number of encoder levels (convolution block + pooling), by default 3.
        num_channels_init : int, optional
            Number of channels (per group) produced by the first encoder block,
            by default 64.
        use_batch_norm : bool, optional
            Whether to use batch normalization, by default True.
        dropout : float, optional
            Dropout probability, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the max pooling layers, by default 2.
        n2v2 : bool, optional
            Whether to use the N2V2 variant (`MaxBlurPool` pooling), by default
            False.
        groups : int, optional
            Number of blocked connections from input channels to output
            channels, by default 1.
        """
        super().__init__()

        # N2V2 replaces plain max pooling with a blur-pooling layer.
        if n2v2:
            self.pooling = MaxBlurPool(
                dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel
            )
        else:
            self.pooling = getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)

        layers: List[nn.Module] = []
        level_in = in_channels
        for level in range(depth):
            level_out = num_channels_init * groups * 2**level
            layers.append(
                Conv_Block(
                    conv_dim,
                    in_channels=level_in,
                    out_channels=level_out,
                    dropout_perc=dropout,
                    use_batch_norm=use_batch_norm,
                    groups=groups,
                )
            )
            # The same pooling module instance follows every block.
            layers.append(self.pooling)
            # The next level consumes this level's output.
            level_in = level_out

        self.encoder_blocks = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Run the encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        List[torch.Tensor]
            The output of the last pooling layer, followed by the output of each
            `Conv_Block` (the skip connections), shallowest level first.
        """
        skips: List[torch.Tensor] = []
        for layer in self.encoder_blocks:
            x = layer(x)
            if isinstance(layer, Conv_Block):
                skips.append(x)
        return [x] + skips
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class UnetDecoder(nn.Module):
    """
    Expanding (decoder) half of a UNet, including the bottleneck.

    The deepest encoder output first goes through a bottleneck `Conv_Block`.
    Each decoder level then upsamples by a factor of 2, interleaves the result
    with the matching skip connection along the channel axis, and applies a
    `Conv_Block`. The last level outputs `num_channels_init * groups` channels.

    Parameters
    ----------
    conv_dim : int
        Number of spatial dimensions of the convolutions, 2 for 2D or 3 for 3D.
        Upsampling is bilinear when this is 2 and trilinear otherwise.
    depth : int, optional
        Number of decoder levels, by default 3.
    num_channels_init : int, optional
        Number of channels (per group) of the first encoder block, by default 64.
    use_batch_norm : bool, optional
        Whether to use batch normalization, by default True.
    dropout : float, optional
        Dropout probability, by default 0.0.
    n2v2 : bool, optional
        Whether the N2V2 variant is used, by default False.
    groups : int, optional
        Number of blocked connections from input channels to output channels,
        by default 1.
    """

    def __init__(
        self,
        conv_dim: int,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        n2v2: bool = False,
        groups: int = 1,
    ) -> None:
        """
        Build the bottleneck and the decoder levels.

        Parameters
        ----------
        conv_dim : int
            Number of spatial dimensions of the convolutions, 2 for 2D or 3 for 3D.
        depth : int, optional
            Number of decoder levels, by default 3.
        num_channels_init : int, optional
            Number of channels (per group) of the first encoder block, by
            default 64.
        use_batch_norm : bool, optional
            Whether to use batch normalization, by default True.
        dropout : float, optional
            Dropout probability, by default 0.0.
        n2v2 : bool, optional
            Whether the N2V2 variant is used, by default False.
        groups : int, optional
            Number of blocked connections from input channels to output
            channels, by default 1.
        """
        super().__init__()

        interp_mode = "bilinear" if conv_dim == 2 else "trilinear"
        # A single upsampling module instance is shared by every level.
        upsampling = nn.Upsample(scale_factor=2, mode=interp_mode)
        bottleneck_channels = num_channels_init * groups * (2 ** (depth - 1))

        self.n2v2 = n2v2
        self.groups = groups

        self.bottleneck = Conv_Block(
            conv_dim,
            in_channels=bottleneck_channels,
            out_channels=bottleneck_channels,
            intermediate_channel_multiplier=2,
            use_batch_norm=use_batch_norm,
            dropout_perc=dropout,
            groups=self.groups,
        )

        layers: List[nn.Module] = []
        for level in range(depth):
            layers.append(upsampling)
            level_in = num_channels_init * groups * 2 ** (depth - level)
            level_out = level_in // 2
            # Assuming skip connections from a UnetEncoder with the same settings:
            # at level 0 the upsampled bottleneck and the skip connection each
            # carry level_in // 2 channels; afterwards the upsampled features carry
            # level_in channels and the skip connection adds half as many.
            if level == 0:
                concat_in = level_in
            else:
                concat_in = level_in + level_in // 2
            layers.append(
                Conv_Block(
                    conv_dim,
                    in_channels=concat_in,
                    out_channels=level_out,
                    intermediate_channel_multiplier=2,
                    dropout_perc=dropout,
                    activation="ReLU",
                    use_batch_norm=use_batch_norm,
                    groups=groups,
                )
            )

        self.decoder_blocks = nn.ModuleList(layers)

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        """
        Run the decoder.

        Parameters
        ----------
        *features : torch.Tensor
            The final output of the encoder, followed by the skip connections
            ordered from the shallowest to the deepest level (the layout
            returned by `UnetEncoder.forward`).

        Returns
        -------
        torch.Tensor
            Output of the decoder.
        """
        x: torch.Tensor = features[0]
        # Reverse the skip connections so the deepest one comes first.
        skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]

        x = self.bottleneck(x)

        for idx, layer in enumerate(self.decoder_blocks):
            x = layer(x)
            if not isinstance(layer, nn.Upsample):
                continue
            # decoder_blocks alternates [Upsample, Conv_Block], so idx // 2 is
            # the decoder level.
            skip = skip_connections[idx // 2]
            # NOTE(review): the N2V2 check compares against the top-level skip
            # connection, which has num_channels_init * groups channels. The
            # upsampled x appears to have that many channels only when depth == 1,
            # where the following Conv_Block would then receive too few channels.
            # For depth >= 2 the top skip connection is therefore still
            # concatenated in N2V2 mode. Confirm this is intended.
            if not self.n2v2 or x.shape != skip_connections[-1].shape:
                x = self._interleave(x, skip, self.groups)
        return x

    @staticmethod
    def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
        """Interleave two tensors group-wise along the channel axis.

        Both tensors are split into `groups` equally sized chunks along axis 1.
        The chunks are then concatenated along axis 1 in the order
        A0, B0, A1, B1, ...

        Parameters
        ----------
        A : torch.Tensor
            First tensor.
        B : torch.Tensor
            Second tensor.
        groups : int
            The number of groups.

        Returns
        -------
        torch.Tensor
            Interleaved tensor.

        Raises
        ------
        ValueError:
            If either of `A` or `B`'s channel axis is not divisible by `groups`.
        """
        if A.shape[1] % groups != 0 or B.shape[1] % groups != 0:
            raise ValueError(f"Number of channels not divisible by {groups} groups.")

        a_width = A.shape[1] // groups
        b_width = B.shape[1] // groups

        pieces: List[torch.Tensor] = []
        for g in range(groups):
            pieces.append(A[:, g * a_width : (g + 1) * a_width])
            pieces.append(B[:, g * b_width : (g + 1) * b_width])
        return torch.cat(pieces, dim=1)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class UNet(nn.Module):
    """
    UNet model.

    Adapted for PyTorch from:
    https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.

    Parameters
    ----------
    conv_dims : int
        Number of dimensions of the convolution layers (2 or 3).
    num_classes : int, optional
        Number of output channels of the final 1x1 convolution, by default 1.
    in_channels : int, optional
        Number of input channels, by default 1.
    depth : int, optional
        Number of downsamplings, by default 3.
    num_channels_init : int, optional
        Number of filters (per group) in the first convolution layer, by
        default 64.
    use_batch_norm : bool, optional
        Whether to use batch normalization, by default True.
    dropout : float, optional
        Dropout probability, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the pooling layers, by default 2.
    final_activation : SupportedActivation or str, optional
        Activation applied to the output of the last layer, resolved with
        `get_activation`, by default `SupportedActivation.NONE`.
    n2v2 : bool, optional
        Whether to use N2V2 architecture, by default False.
    independent_channels : bool, optional
        Whether to train parallel independent networks for each input channel,
        by default True. When True, every convolution uses
        `groups=in_channels`, so `num_classes` must be divisible by
        `in_channels`.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        conv_dims: int,
        num_classes: int = 1,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
        n2v2: bool = False,
        independent_channels: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Assemble encoder, decoder, final 1x1 convolution and final activation.

        Parameters
        ----------
        conv_dims : int
            Number of dimensions of the convolution layers (2 or 3).
        num_classes : int, optional
            Number of output channels of the final 1x1 convolution, by default 1.
        in_channels : int, optional
            Number of input channels, by default 1.
        depth : int, optional
            Number of downsamplings, by default 3.
        num_channels_init : int, optional
            Number of filters (per group) in the first convolution layer, by
            default 64.
        use_batch_norm : bool, optional
            Whether to use batch normalization, by default True.
        dropout : float, optional
            Dropout probability, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the pooling layers, by default 2.
        final_activation : SupportedActivation or str, optional
            Activation applied to the output of the last layer, by default
            `SupportedActivation.NONE`.
        n2v2 : bool, optional
            Whether to use N2V2 architecture, by default False.
        independent_channels : bool, optional
            Whether to train parallel independent networks for each channel, by
            default True.
        **kwargs : Any
            Additional keyword arguments, unused.
        """
        super().__init__()

        # One convolution group per input channel yields independent sub-networks.
        if independent_channels:
            groups = in_channels
        else:
            groups = 1

        shared_options = dict(
            depth=depth,
            num_channels_init=num_channels_init,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            n2v2=n2v2,
            groups=groups,
        )
        self.encoder = UnetEncoder(
            conv_dims,
            in_channels=in_channels,
            pool_kernel=pool_kernel,
            **shared_options,
        )
        self.decoder = UnetDecoder(conv_dims, **shared_options)

        conv_cls = getattr(nn, f"Conv{conv_dims}d")
        # The decoder's last level outputs num_channels_init * groups channels.
        self.final_conv = conv_cls(
            in_channels=num_channels_init * groups,
            out_channels=num_classes,
            kernel_size=1,
            groups=groups,
        )
        self.final_activation = get_activation(final_activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the full UNet.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output of the model.
        """
        features = self.encoder(x)
        decoded = self.decoder(*features)
        projected = self.final_conv(decoded)
        return self.final_activation(projected)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Package to house various prediction utilies."""
|
|
2
|
+
|
|
3
|
+
# Public API of this package; each name is re-exported from the submodule
# imports that follow.
__all__ = [
    "stitch_prediction",
    "stitch_prediction_single",
    "convert_outputs",
]
|
|
8
|
+
|
|
9
|
+
from .prediction_outputs import convert_outputs
|
|
10
|
+
from .stitch_prediction import stitch_prediction, stitch_prediction_single
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Module containing pytorch implementations for obtaining predictions from an LVAE."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from careamics.models.lvae import LadderVAE as LVAE
|
|
8
|
+
from careamics.models.lvae.likelihoods import LikelihoodModule
|
|
9
|
+
|
|
10
|
+
# TODO: convert these functions to lightning module `predict_step`
|
|
11
|
+
# -> mmse_count will have to be an instance attribute?
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Batches produced by the prediction datasets can include auxiliary items, such as
# the TileInformation, alongside the model input. This function operates on the
# model input tensor alone so that it can be reused by lvae_predict_tiled_batch and
# lvae_predict_mmse_tiled_batch, which unpack the batch before calling it.
|
|
17
|
+
def lvae_predict_single_sample(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Generate a single sample prediction from an LVAE model, for a given input.

    The model is switched to evaluation mode, and is left in that mode after
    this function returns. It is then run without gradient tracking. Its output
    is split into a prediction and an optional log-variance by
    `likelihood_obj.get_mean_lv`.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.Tensor
        Input to generate prediction for. Expected shape is (S, C, Y, X).

    Returns
    -------
    tuple of (torch.Tensor, torch.Tensor or None)
        The sample prediction and the log-variance. The log-variance will be
        None if `model.predict_logvar is None`.
    """
    # Not in the original predict code: eval mode affects batch-norm and dropout
    # layers. NOTE: the caller's model stays in eval mode afterwards.
    model.eval()
    with torch.no_grad():
        # The model returns (output, top-down data dict); only the output is used.
        raw_output, _td_data = model(input)

        # Presently get_mean_lv splits the output in two when predict_logvar=True
        # and optionally clips the logvar if logvar_lowerbound is not None.
        # TODO: consider refactoring to remove use of the likelihood object
        prediction, lv = likelihood_obj.get_mean_lv(raw_output)

    # TODO: output denormalization using target stats that will be saved in data
    # config -> possibly not needed.
    return prediction, lv
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def lvae_predict_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any],
) -> tuple[tuple[Any], Optional[tuple[Any]]]:
    """
    Generate a single sample prediction for a batch that may carry auxiliary data.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : tuple of (torch.Tensor, Any, ...)
        Batch to generate a prediction for. The model input is always the first
        item. Any further items (e.g. `TileInformation`) are passed through
        unchanged. Expected shape of the model input is (S, C, Y, X). Note that
        the batch is unpacked with `first, *rest = input`, so a bare tensor
        would be split along its first dimension.

    Returns
    -------
    tuple of ((torch.Tensor, Any, ...), (torch.Tensor, Any, ...) or None)
        The sample prediction followed by the auxiliary items, and the
        log-variance followed by the same auxiliary items. The second element is
        None if `model.predict_logvar is None`.
    """
    # TODO: fix docstring return types, ... too many output options
    model_input, *extras = input

    prediction, log_var = lvae_predict_single_sample(
        model=model, likelihood_obj=likelihood_obj, input=model_input
    )

    prediction_output = (prediction, *extras)
    if log_var is None:
        return prediction_output, None
    return prediction_output, (log_var, *extras)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def lvae_predict_mmse_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any],
    mmse_count: int,
) -> tuple[tuple[Any, ...], tuple[Any, ...], Optional[tuple[Any, ...]]]:
    """
    Generate the MMSE (minimum mean squared error) prediction, for a given input.

    This is calculated from the mean of multiple single sample predictions.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : tuple of (torch.Tensor, Any, ...)
        Batch to generate predictions for. The model input is always the first
        item. Any further items (e.g. `TileInformation`) are passed through
        unchanged. Expected shape of the model input is (S, C, Y, X).
    mmse_count : int
        Number of samples to generate to calculate the MMSE prediction.

    Returns
    -------
    tuple of (tuple, tuple, tuple or None)
        - The MMSE prediction (mean of the samples), followed by the auxiliary
          items.
        - The standard deviation of the samples, followed by the auxiliary
          items. This is all zeros when `mmse_count == 1`.
        - The log-variance of the first sample, followed by the auxiliary items,
          or None if no log-variance was returned (`likelihood.predict_logvar`
          is None).

        All tensors keep the device and dtype of the model's predictions.

    Raises
    ------
    ValueError
        If `mmse_count` is not greater than zero.
    """
    if mmse_count <= 0:
        raise ValueError("MMSE count must be greater than zero.")

    x: torch.Tensor
    aux: list[Any]
    x, *aux = input

    log_var: Optional[torch.Tensor] = None
    samples: list[torch.Tensor] = []
    for mmse_idx in range(mmse_count):
        sample_prediction, lv = lvae_predict_single_sample(
            model=model, likelihood_obj=likelihood_obj, input=x
        )
        # only keep the log variance of the first sample prediction
        if mmse_idx == 0:
            log_var = lv
        samples.append(sample_prediction)

    # Stack the samples instead of filling a pre-allocated default (CPU, float32)
    # buffer, so the results stay on the model's device and keep its dtype,
    # consistent with `log_var`.
    sample_predictions = torch.stack(samples, dim=0)

    mmse_prediction = torch.mean(sample_predictions, dim=0)
    if mmse_count > 1:
        mmse_prediction_std = torch.std(sample_predictions, dim=0)
    else:
        # The unbiased std of a single sample is NaN; report zero spread instead.
        mmse_prediction_std = torch.zeros_like(mmse_prediction)

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (mmse_prediction, *aux), (mmse_prediction_std, *aux), log_var_output
|