careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/lvae.py
CHANGED
```diff
@@ -1,58 +1,94 @@
 """
-Ladder VAE (LVAE) Model
+Ladder VAE (LVAE) Model.
 
-The current implementation is based on "Interpretable Unsupervised Diversity Denoising
+The current implementation is based on "Interpretable Unsupervised Diversity Denoising
+and Artefact Removal, Prakash et al."
 """
 
 from collections.abc import Iterable
-from typing import
+from typing import Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 
-from careamics.config.architectures import register_model
-
 from ..activation import get_activation
 from .layers import (
     BottomUpDeterministicResBlock,
     BottomUpLayer,
+    GateLayer,
     TopDownDeterministicResBlock,
     TopDownLayer,
 )
-from .utils import Interpolate, ModelType, crop_img_tensor
+from .utils import Interpolate, ModelType, crop_img_tensor
 
 
-@register_model("LVAE")
 class LadderVAE(nn.Module):
+    """
+    Constructor.
+
+    Parameters
+    ----------
+    input_shape : int
+        The size of the input image.
+    output_channels : int
+        The number of output channels.
+    multiscale_count : int
+        The number of scales for multiscale processing.
+    z_dims : list[int]
+        The dimensions of the latent space for each layer.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_conv_strides : list[int]
+        The strides for the conv layers encoder.
+    decoder_conv_strides : list[int]
+        The strides for the conv layers decoder.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : str
+        The nonlinearity function to use.
+    predict_logvar : bool
+        Whether to predict the log variance.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+
+    Raises
+    ------
+    NotImplementedError
+        If only 2D convolutions are supported.
+    """
 
     def __init__(
         self,
         input_shape: int,
         output_channels: int,
         multiscale_count: int,
-        z_dims:
+        z_dims: list[int],
         encoder_n_filters: int,
         decoder_n_filters: int,
+        encoder_conv_strides: list[int],
+        decoder_conv_strides: list[int],
         encoder_dropout: float,
         decoder_dropout: float,
         nonlinearity: str,
         predict_logvar: bool,
         analytical_kl: bool,
     ):
-        """
-        Constructor.
-
-        Parameters
-        ----------
-
-        """
         super().__init__()
 
         # -------------------------------------------------------
         # Customizable attributes
         self.image_size = input_shape
+        """Input image size. (Z, Y, X) or (Y, X) if the data is 2D."""
+        # TODO: we need to be careful with this since used to be an int.
+        # the tuple of shapes used to be `self.input_shape`.
         self.target_ch = output_channels
+        self.encoder_conv_strides = encoder_conv_strides
+        self.decoder_conv_strides = decoder_conv_strides
         self._multiscale_count = multiscale_count
         self.z_dims = z_dims
         self.encoder_n_filters = encoder_n_filters
@@ -80,7 +116,6 @@ class LadderVAE(nn.Module):
         self.merge_type = "residual"
         self.no_initial_downscaling = True
         self.skip_bottomk_buvalues = 0
-        self.non_stochastic_version = False
         self.stochastic_skip = True
         self.learn_top_prior = True
         self.res_block_type = "bacdbacd"  # TODO remove !
@@ -91,9 +126,7 @@ class LadderVAE(nn.Module):
         self._enable_topdown_normalize_factor = True
 
         # Attributes that handle LC -> Hardcoded
-        self.enable_multiscale = (
-            self._multiscale_count is not None and self._multiscale_count > 1
-        )
+        self.enable_multiscale = self._multiscale_count > 1
         self.multiscale_retain_spatial_dims = True
         self.multiscale_lowres_separate_branch = False
         self.multiscale_decoder_retain_spatial_dims = (
@@ -102,14 +135,6 @@ class LadderVAE(nn.Module):
 
         # Derived attributes
         self.n_layers = len(self.z_dims)
-        self.encoder_no_padding_mode = (
-            self.encoder_res_block_skip_padding is True
-            and self.encoder_res_block_kernel > 1
-        )
-        self.decoder_no_padding_mode = (
-            self.decoder_res_block_skip_padding is True
-            and self.decoder_res_block_kernel > 1
-        )
 
         # Others...
         self._tethered_to_input = False
@@ -127,19 +152,41 @@ class LadderVAE(nn.Module):
 
         # -------------------------------------------------------
         # Data attributes
-        self.color_ch = 1
-        self.img_shape = (self.image_size, self.image_size)
+        self.color_ch = 1  # TODO for now we only support 1 channel
         self.normalized_input = True
         # -------------------------------------------------------
 
         # -------------------------------------------------------
         # Loss attributes
-        self._restricted_kl = False  # HC
         # enabling reconstruction loss on mixed input
         self.mixed_rec_w = 0
         self.nbr_consistency_w = 0
 
         # -------------------------------------------------------
+        # 3D related stuff
+        self._mode_3D = len(self.image_size) == 3  # TODO refac
+        self._model_3D_depth = self.image_size[0] if self._mode_3D else 1
+        self._decoder_mode_3D = len(self.decoder_conv_strides) == 3
+        if self._mode_3D and not self._decoder_mode_3D:
+            assert self._model_3D_depth % 2 == 1, "3D model depth should be odd"
+        assert (
+            self._mode_3D is True or self._decoder_mode_3D is False
+        ), "Decoder cannot be 3D when encoder is 2D"
+        self._squish3d = self._mode_3D and not self._decoder_mode_3D
+        self._3D_squisher = (
+            None
+            if not self._squish3d
+            else nn.ModuleList(
+                [
+                    GateLayer(
+                        channels=self.encoder_n_filters,
+                        conv_strides=self.encoder_conv_strides,
+                    )
+                    for k in range(len(self.z_dims))
+                ]
+            )
+        )
+        # TODO: this bit is in the Ashesh's confusing-hacky style... Can we do better?
 
         # -------------------------------------------------------
         # # Training attributes
```
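The `_3D_squisher` introduced above lets a 3D encoder feed a 2D decoder: each bottom-up feature map is gated and then averaged over the depth axis (see the `forward` hunk further down). A minimal sketch of that squeeze, with assumed shapes and an identity in place of `GateLayer`:

```python
import torch

# Hypothetical encoder feature map (B, C, Z, Y, X); GateLayer is replaced by an
# identity here. Averaging over dim=2 collapses depth: (B, C, Z, Y, X) -> (B, C, Y, X).
bu_value = torch.randn(4, 32, 5, 64, 64)
squished = torch.mean(bu_value, dim=2)
print(squished.shape)  # torch.Size([4, 32, 64, 64])
```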
```diff
@@ -168,6 +215,11 @@ class LadderVAE(nn.Module):
         ### CREATE MODEL BLOCKS
         # First bottom-up layer: change num channels + downsample by factor 2
         # unless we want to prevent this
+        self.encoder_conv_op = getattr(nn, f"Conv{len(self.encoder_conv_strides)}d")
+        # TODO these should be defined for all layers here ?
+        self.decoder_conv_op = getattr(nn, f"Conv{len(self.decoder_conv_strides)}d")
+        # TODO: would be more readable to have a derived parameters to use like
+        # `conv_dims = len(self.encoder_conv_strides)` and then use `Conv{conv_dims}d`
         stride = 1 if self.no_initial_downscaling else 2
         self.first_bottom_up = self.create_first_bottom_up(stride)
 
```
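The two `getattr` lines resolve the convolution operator from the length of the configured stride tuple, so the same code path builds 2D and 3D networks. A minimal sketch of the dispatch (the stride values are illustrative assumptions):

```python
import torch.nn as nn

encoder_conv_strides = (1, 2, 2)  # hypothetical config: 3 entries -> 3D convolutions
conv_op = getattr(nn, f"Conv{len(encoder_conv_strides)}d")  # resolves to nn.Conv3d
layer = conv_op(in_channels=1, out_channels=32, kernel_size=3, padding=1)
print(type(layer).__name__)  # Conv3d
```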
```diff
@@ -191,7 +243,7 @@ class LadderVAE(nn.Module):
 
         # Output layer --> Project to target_ch many channels
         logvar_ch_needed = self.predict_logvar is not None
-        self.output_layer = self.parameter_net =
+        self.output_layer = self.parameter_net = self.decoder_conv_op(
             self.decoder_n_filters,
             self.target_ch * (1 + logvar_ch_needed),
             kernel_size=3,
@@ -205,6 +257,7 @@ class LadderVAE(nn.Module):
         # PSNR computation on validation.
         # self.label1_psnr = RunningPSNR()
         # self.label2_psnr = RunningPSNR()
+        # TODO: did you add this?
 
         # msg =f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
         # msg += f' TargetCh: {self.target_ch}'
@@ -217,7 +270,8 @@ class LadderVAE(nn.Module):
         num_res_blocks: int = 1,
     ) -> nn.Sequential:
         """
-
+        Method creates the first bottom-up block of the Encoder.
+
         Its role is to perform a first image compression step.
         It is composed by a sequence of nn.Conv2d + non-linearity +
         BottomUpDeterministicResBlock (1 or more, default is 1).
@@ -225,29 +279,30 @@ class LadderVAE(nn.Module):
         Parameters
         ----------
         init_stride: int
-            The stride used by the
+            The stride used by the intial Conv2d block.
         num_res_blocks: int, optional
-            The number of BottomUpDeterministicResBlocks
+            The number of BottomUpDeterministicResBlocks, default is 1.
         """
+        # From what I got from Ashesh, Z should not be touched in any case.
         nonlin = get_activation(self.nonlin)
-
-
-
-
-
-
-
-
-                else self.encoder_res_block_kernel // 2
-            ),
-            stride=init_stride,
+        conv_block = self.encoder_conv_op(
+            in_channels=self.color_ch,
+            out_channels=self.encoder_n_filters,
+            kernel_size=self.encoder_res_block_kernel,
+            padding=(
+                0
+                if self.encoder_res_block_skip_padding
+                else self.encoder_res_block_kernel // 2
             ),
-
-
+            stride=init_stride,
+        )
+
+        modules = [conv_block, nonlin]
 
         for _ in range(num_res_blocks):
             modules.append(
                 BottomUpDeterministicResBlock(
+                    conv_strides=self.encoder_conv_strides,
                     c_in=self.encoder_n_filters,
                     c_out=self.encoder_n_filters,
                     nonlin=nonlin,
@@ -255,7 +310,6 @@ class LadderVAE(nn.Module):
                     batchnorm=self.bottomup_batchnorm,
                     dropout=self.encoder_dropout,
                     res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                     res_block_kernel=self.encoder_res_block_kernel,
                 )
             )
@@ -264,7 +318,8 @@ class LadderVAE(nn.Module):
 
     def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
         """
-
+        Method creates the stack of bottom-up layers of the Encoder.
+
         that are used to generate the so-called `bu_values`.
 
         NOTE:
@@ -274,8 +329,9 @@ class LadderVAE(nn.Module):
         Parameters
         ----------
         lowres_separate_branch: bool
-            Whether the residual block(s) used for encoding the low-res input are shared
-            not (`True`) with the "same-size" residual block(s) in the
+            Whether the residual block(s) used for encoding the low-res input are shared
+            (`False`) or not (`True`) with the "same-size" residual block(s) in the
+            `BottomUpLayer`'s primary flow.
         """
         multiscale_lowres_size_factor = 1
         nonlin = get_activation(self.nonlin)
@@ -294,11 +350,11 @@ class LadderVAE(nn.Module):
             # N.B. Only used if layer_enable_multiscale == True, so we updated it only in that case
             multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
 
-
-
-
-
-
+            # TODO: check correctness of this
+            if self._multiscale_count > 1:
+                output_expected_shape = (dim // 2 ** (i + 1) for dim in self.image_size)
+            else:
+                output_expected_shape = None
 
             # Add bottom-up deterministic layer at level i.
             # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them.
@@ -308,14 +364,14 @@ class LadderVAE(nn.Module):
                     n_filters=self.encoder_n_filters,
                     downsampling_steps=self.downsample[i],
                     nonlin=nonlin,
+                    conv_strides=self.encoder_conv_strides,
                     batchnorm=self.bottomup_batchnorm,
                     dropout=self.encoder_dropout,
                     res_block_type=self.res_block_type,
                     res_block_kernel=self.encoder_res_block_kernel,
-                    res_block_skip_padding=self.encoder_res_block_skip_padding,
                     gated=self.gated,
                     lowres_separate_branch=lowres_separate_branch,
-                    enable_multiscale=self.enable_multiscale,  # shouldn't the arg be `layer_enable_multiscale` here?
+                    enable_multiscale=self.enable_multiscale,  # TODO: shouldn't the arg be `layer_enable_multiscale` here?
                     multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
                     multiscale_lowres_size_factor=multiscale_lowres_size_factor,
                     decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
@@ -327,7 +383,8 @@ class LadderVAE(nn.Module):
 
     def create_top_down_layers(self) -> nn.ModuleList:
         """
-
+        Method creates the stack of top-down layers of the Decoder.
+
         In these layer the `bu`_values` from the Encoder are merged with the `p_params` from the previous layer
         of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
         with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
@@ -346,8 +403,6 @@ class LadderVAE(nn.Module):
         When doing unconditional generation, bu_value is not available. Hence the
         merge layer is not used, and z is sampled directly from p_params.
 
-        Parameters
-        ----------
         """
         top_down_layers = nn.ModuleList([])
         nonlin = get_activation(self.nonlin)
@@ -356,7 +411,7 @@ class LadderVAE(nn.Module):
             # Check if this is the top layer
             is_top = i == self.n_layers - 1
 
-            if self._enable_topdown_normalize_factor:
+            if self._enable_topdown_normalize_factor:  # TODO: What is this?
                 normalize_latent_factor = (
                     1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
                 )
@@ -369,7 +424,8 @@ class LadderVAE(nn.Module):
                     n_res_blocks=self.decoder_blocks_per_layer,
                     n_filters=self.decoder_n_filters,
                     is_top_layer=is_top,
-
+                    conv_strides=self.decoder_conv_strides,
+                    upsampling_steps=self.downsample[i],
                     nonlin=nonlin,
                     merge_type=self.merge_type,
                     batchnorm=self.topdown_batchnorm,
@@ -379,17 +435,11 @@ class LadderVAE(nn.Module):
                     top_prior_param_shape=self.get_top_prior_param_shape(),
                     res_block_type=self.res_block_type,
                     res_block_kernel=self.decoder_res_block_kernel,
-                    res_block_skip_padding=self.decoder_res_block_skip_padding,
                     gated=self.gated,
                     analytical_kl=self.analytical_kl,
-                    restricted_kl=self._restricted_kl,
                     vanilla_latent_hw=self.get_latent_spatial_size(i),
-                    # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so merging operation does not work natively.
-                    bottomup_no_padding_mode=self.encoder_no_padding_mode,
-                    topdown_no_padding_mode=self.decoder_no_padding_mode,
                     retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
-
-                    input_image_shape=self.img_shape,
+                    input_image_shape=self.image_size,
                     normalize_latent_factor=normalize_latent_factor,
                     conv2d_bias=self.topdown_conv2d_bias,
                     stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
@@ -398,8 +448,10 @@ class LadderVAE(nn.Module):
         return top_down_layers
 
     def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
-        """
-
+        """Create the final top-down layer of the Decoder.
+
+        NOTE: In this layer, (optional) upsampling is performed by bilinear interpolation
+        instead of transposed convolution (like in other TD layers).
 
         Parameters
         ----------
@@ -419,69 +471,76 @@ class LadderVAE(nn.Module):
                 c_in=self.decoder_n_filters,
                 c_out=self.decoder_n_filters,
                 nonlin=get_activation(self.nonlin),
+                conv_strides=self.decoder_conv_strides,
                 batchnorm=self.topdown_batchnorm,
                 dropout=self.decoder_dropout,
                 res_block_type=self.res_block_type,
                 res_block_kernel=self.decoder_res_block_kernel,
-                skip_padding=self.decoder_res_block_skip_padding,
                 gated=self.gated,
                 conv2d_bias=self.topdown_conv2d_bias,
             )
         )
         return nn.Sequential(*modules)
 
-    def _init_multires(
-        self, config=None
-    ) -> nn.ModuleList:  # TODO config: ml_collections.ConfigDict refactor
+    def _init_multires(self, config=None) -> nn.ModuleList:
         """
-
-
-
-
-
-
-
-
-
-
-
+        Method defines the input block/branch to encode/compress low-res lateral inputs.
+
+        at different hierarchical levels
+        in the multiresolution approach (LC). The role of the input branches is similar
+        to the one of the first bottom-up layer in the primary flow of the Encoder,
+        namely to compress the lateral input image to a degree that is compatible with
+        the one of the primary flow.
+
+        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity
+        + BottomUpDeterministicResBlock. It is meaningful to observe that the
+        `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc. etc.).
+        Moreover, it does not perform downsampling.
+
+        NOTE 2: `_multiscale_count` attribute defines the total number of inputs to the
+        bottom-up pass. In other terms if we have the input patch and n_LC additional
+        lateral inputs, we will have a total of (n_LC + 1) inputs.
         """
         stride = 1 if self.no_initial_downscaling else 2
         nonlin = get_activation(self.nonlin)
         if self._multiscale_count is None:
             self._multiscale_count = 1
 
-        msg =
-
+        msg = (
+            f"Multiscale count ({self._multiscale_count}) should not exceed the number"
+            f"of bottom up layers ({self.n_layers}) by more than 1.\n"
+        )
         assert (
             self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
-        ), msg
+        ), msg  # TODO how ?
 
         msg = (
-            "
+            "Multiscale approach only supports monocrome images. "
+            f"Found instead color_ch={self.color_ch}."
         )
-        assert self._multiscale_count == 1 or self.color_ch == 1, msg
+        # assert self._multiscale_count == 1 or self.color_ch == 1, msg
 
         lowres_first_bottom_ups = []
         for _ in range(1, self._multiscale_count):
             first_bottom_up = nn.Sequential(
-
+                self.encoder_conv_op(
                     in_channels=self.color_ch,
                     out_channels=self.encoder_n_filters,
                     kernel_size=5,
-                    padding=
+                    padding="same",
                     stride=stride,
                 ),
                 nonlin,
                 BottomUpDeterministicResBlock(
                     c_in=self.encoder_n_filters,
                     c_out=self.encoder_n_filters,
+                    conv_strides=self.encoder_conv_strides,
                     nonlin=nonlin,
                     downsample=False,
                     batchnorm=self.bottomup_batchnorm,
                     dropout=self.encoder_dropout,
                     res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                 ),
             )
             lowres_first_bottom_ups.append(first_bottom_up)
```
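NOTE 2 in the `_init_multires` docstring above describes the lateral-context (LC) input convention that `_bottomup_pass` relies on (the `inp[:, i + 1 : i + 2]` slicing in the next hunk): the primary patch and the n_LC low-res crops are stacked along the channel axis. A sketch with assumed shapes:

```python
import torch

# With _multiscale_count = 3 the network receives the primary patch plus
# 2 low-res lateral crops, stacked along the channel axis: (B, 1 + n_LC, Y, X).
multiscale_count = 3
batch = torch.randn(8, multiscale_count, 64, 64)
primary = batch[:, 0:1]                                          # main input patch
laterals = [batch[:, i + 1 : i + 2] for i in range(multiscale_count - 1)]
print(primary.shape, [tuple(t.shape) for t in laterals])
```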
```diff
@@ -493,10 +552,9 @@ class LadderVAE(nn.Module):
         )
 
     ### SET OF FORWARD-LIKE METHODS
-    def bottomup_pass(self, inp: torch.Tensor) ->
-        """
-
-        """
+    def bottomup_pass(self, inp: torch.Tensor) -> list[torch.Tensor]:
+        """Wrapper of _bottomup_pass()."""
+        # TODO Remove wrapper
         return self._bottomup_pass(
             inp,
             self.first_bottom_up,
@@ -510,9 +568,10 @@ class LadderVAE(nn.Module):
         first_bottom_up: nn.Sequential,
         lowres_first_bottom_ups: nn.ModuleList,
         bottom_up_layers: nn.ModuleList,
-    ) ->
+    ) -> list[torch.Tensor]:
         """
-
+        Method defines the forward pass through the LVAE Encoder, the so-called.
+
         Bottom-Up pass.
 
         Parameters
@@ -541,7 +600,6 @@ class LadderVAE(nn.Module):
             lowres_x = None
             if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
                 lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
-
             x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
             bu_values.append(bu_value)
 
@@ -549,41 +607,40 @@ class LadderVAE(nn.Module):
 
     def topdown_pass(
         self,
-        bu_values: torch.Tensor = None,
-        n_img_prior: torch.Tensor = None,
-
-
-
-
-
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        bu_values: Union[torch.Tensor, None] = None,
+        n_img_prior: Union[torch.Tensor, None] = None,
+        constant_layers: Union[Iterable[int], None] = None,
+        forced_latent: Union[list[torch.Tensor], None] = None,
+        top_down_layers: Union[nn.ModuleList, None] = None,
+        final_top_down_layer: Union[nn.Sequential, None] = None,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
         """
-
+        Method defines the forward pass through the LVAE Decoder, the so-called.
+
         Top-Down pass.
 
         Parameters
         ----------
         bu_values: torch.Tensor, optional
-            Output of the bottom-up pass. It will have values from multiple layers of
+            Output of the bottom-up pass. It will have values from multiple layers of
+            the ladder.
         n_img_prior: optional
-            When `bu_values` is `None`, `n_img_prior` indicates the number of images to
+            When `bu_values` is `None`, `n_img_prior` indicates the number of images to
+            generate
             from the prior (so bottom-up pass is not used at all here).
-        mode_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which sampling is disabled and
-            the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
         constant_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which a single instance's
-            copied over the entire batch (bottom-up path is not used, so only prior
-            Set to `None` to avoid this behaviour.
-        forced_latent:
-            A list of tensors that are used as fixed latent variables (hence, sampling
-            place in this case).
+            A sequence of indexes associated to the layers in which a single instance's
+            z is copied over the entire batch (bottom-up path is not used, so only prior
+            is used here). Set to `None` to avoid this behaviour.
+        forced_latent: list[torch.Tensor], optional
+            A list of tensors that are used as fixed latent variables (hence, sampling
+            doesn't take place in this case).
         top_down_layers: nn.ModuleList, optional
-            A list of top-down layers to use in the top-down pass. If `None`, the method
-            default layers defined in the constructor.
+            A list of top-down layers to use in the top-down pass. If `None`, the method
+            uses the default layers defined in the constructor.
         final_top_down_layer: nn.Sequential, optional
-            The last top-down layer of the top-down pass. If `None`, the method uses the
-            layers defined in the constructor.
+            The last top-down layer of the top-down pass. If `None`, the method uses the
+            default layers defined in the constructor.
         """
         if top_down_layers is None:
             top_down_layers = self.top_down_layers
```
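The reworked `topdown_pass` signature makes its two modes explicit: inference (pass `bu_values` from the encoder) or unconditional generation (pass `n_img_prior` and sample from the prior). A self-contained sketch of that optional-argument convention; `decode` is a stand-in for illustration, not the careamics method:

```python
from typing import Union

import torch

def decode(
    bu_values: Union[list[torch.Tensor], None] = None,
    n_img_prior: Union[int, None] = None,
) -> str:
    # Inference: merge encoder outputs with the prior at each ladder level.
    if bu_values is not None:
        return "inference"
    # Generation: no encoder values, sample n_img_prior images from the prior.
    if n_img_prior is not None:
        return "generation"
    raise RuntimeError("provide either bu_values or n_img_prior")

print(decode(bu_values=[torch.zeros(1, 32, 8, 8)]))  # inference
print(decode(n_img_prior=4))                         # generation
```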
```diff
@@ -591,11 +648,9 @@ class LadderVAE(nn.Module):
             final_top_down_layer = self.final_top_down
 
         # Default: no layer is sampled from the distribution's mode
-        if mode_layers is None:
-            mode_layers = []
         if constant_layers is None:
             constant_layers = []
-        prior_experiment = len(
+        prior_experiment = len(constant_layers) > 0
 
         # If the bottom-up inference values are not given, don't do
         # inference, sample from prior instead
@@ -608,11 +663,7 @@ class LadderVAE(nn.Module):
                 "if and only if we're not doing inference"
             )
             raise RuntimeError(msg)
-        if (
-            inference_mode
-            and prior_experiment
-            and (self.non_stochastic_version is False)
-        ):
+        if inference_mode and prior_experiment:
             msg = (
                 "Prior experiments (e.g. sampling from mode) are not"
                 " compatible with inference mode"
@@ -621,34 +672,24 @@ class LadderVAE(nn.Module):
 
         # Sampled latent variables at each layer
         z = [None] * self.n_layers
-
         # KL divergence of each layer
         kl = [None] * self.n_layers
         # Kl divergence restricted, only for the LC enabled setup denoiSplit.
         kl_restricted = [None] * self.n_layers
-
         # mean from which z is sampled.
         q_mu = [None] * self.n_layers
         # log(var) from which z is sampled.
         q_lv = [None] * self.n_layers
-
         # Spatial map of KL divergence for each layer
         kl_spatial = [None] * self.n_layers
-
         debug_qvar_max = [None] * self.n_layers
-
         kl_channelwise = [None] * self.n_layers
-
         if forced_latent is None:
             forced_latent = [None] * self.n_layers
 
-        # log p(z) where z is the sample in the topdown pass
-        # logprob_p = 0.
-
         # Top-down inference/generation loop
-        out =
+        out = None
         for i in reversed(range(self.n_layers)):
-
             # If available, get deterministic node from bottom-up inference
             try:
                 bu_value = bu_values[i]
@@ -656,26 +697,23 @@ class LadderVAE(nn.Module):
                 bu_value = None
 
             # Whether the current layer should be sampled from the mode
-            use_mode = i in mode_layers
             constant_out = i in constant_layers
 
             # Input for skip connection
-            skip_input = out
+            skip_input = out
 
             # Full top-down layer, including sampling and deterministic part
-            out,
+            out, aux = top_down_layers[i](
                 input_=out,
                 skip_connection_input=skip_input,
                 inference_mode=inference_mode,
                 bu_value=bu_value,
                 n_img_prior=n_img_prior,
-                use_mode=use_mode,
                 force_constant_output=constant_out,
                 forced_latent=forced_latent[i],
                 mode_pred=self.mode_pred,
                 var_clip_max=self._var_clip_max,
             )
-
             # Save useful variables
             z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
             kl[i] = aux["kl_samplewise"]  # (batch, )
@@ -708,8 +746,10 @@ class LadderVAE(nn.Module):
         }
         return out, data
 
-    def forward(self, x: torch.Tensor) ->
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
         """
+        Forward pass through the LVAE model.
+
         Parameters
         ----------
         x: torch.Tensor
@@ -717,124 +757,40 @@ class LadderVAE(nn.Module):
         """
         img_size = x.size()[2:]
 
-        # Pad input to size equal to the closest power of 2
-        x_pad = self.pad_input(x)
-
         # Bottom-up inference: return list of length n_layers (bottom to top)
-        bu_values = self.bottomup_pass(
+        bu_values = self.bottomup_pass(x)
         for i in range(0, self.skip_bottomk_buvalues):
             bu_values[i] = None
 
-
+        if self._squish3d:
+            bu_values = [
+                torch.mean(self._3D_squisher[k](bu_value), dim=2)
+                for k, bu_value in enumerate(bu_values)
+            ]
 
         # Top-down inference/generation
-        out, td_data = self.topdown_pass(bu_values
+        out, td_data = self.topdown_pass(bu_values)
 
         if out.shape[-1] > img_size[-1]:
             # Restore original image size
             out = crop_img_tensor(out, img_size)
 
         out = self.output_layer(out)
-        if self._tethered_to_input:
-            assert out.shape[1] == 1
-            ch2 = self.get_other_channel(out, x_pad)
-            out = torch.cat([out, ch2], dim=1)
 
         return out, td_data
 
-    ### SET OF UTILS METHODS
-    # def sample_prior(
-    #     self,
-    #     n_imgs,
-    #     mode_layers=None,
-    #     constant_layers=None
-    # ):
-
-    #     # Generate from prior
-    #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
-    #     out = crop_img_tensor(out, self.img_shape)
-
-    #     # Log likelihood and other info (per data point)
-    #     _, likelihood_data = self.likelihood(out, None)
-
-    #     return likelihood_data['sample']
-
-    # ### ???
-    # def sample_from_q(self, x, masks=None):
-    #     """
-    #     This method performs the bottomup_pass() and samples from the
-    #     obtained distribution.
-    #     """
-    #     img_size = x.size()[2:]
-
-    #     # Pad input to make everything easier with conv strides
-    #     x_pad = self.pad_input(x)
-
-    #     # Bottom-up inference: return list of length n_layers (bottom to top)
-    #     bu_values = self.bottomup_pass(x_pad)
-    #     return self._sample_from_q(bu_values, masks=masks)
-    # ### ???
-
-    # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
-    #     if top_down_layers is None:
-    #         top_down_layers = self.top_down_layers
-    #     if final_top_down_layer is None:
-    #         final_top_down_layer = self.final_top_down
-    #     if masks is None:
-    #         masks = [None] * len(bu_values)
-
-    #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
-    #     assert self.n_layers == 1, msg
-    #     samples = []
-    #     for i in reversed(range(self.n_layers)):
-    #         bu_value = bu_values[i]
-
-    #         # Note that the first argument can be set to None since we are just dealing with one level
-    #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
-    #         samples.append(sample)
-
-    #     return samples
-
-    def reset_for_different_output_size(self, output_size: int) -> None:
-        """Reset shape of output and latent tensors for different output size.
-
-        Used during evaluation to reset expected shapes of tensors when
-        input/output shape changes.
-        For instance, it is needed when the model was trained on, say, 64x64 sized
-        patches, but prediction is done on 128x128 patches.
-        """
-        for i in range(self.n_layers):
-            sz = output_size // 2 ** (1 + i)
-            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-            self.top_down_layers[i].latent_shape = (output_size, output_size)
-
-    def pad_input(self, x):
-        """
-        Pads input x so that its sizes are powers of 2
-        :param x:
-        :return: Padded tensor
-        """
-        size = self.get_padded_size(x.size())
-        x = pad_img_tensor(x, size)
-        return x
-
     ### SET OF GETTERS
     def get_padded_size(self, size):
         """
         Returns the smallest size (H, W) of the image with actual size given
         as input, such that H and W are powers of 2.
-        :param size: input size, tuple either (N, C, H,
+        :param size: input size, tuple either (N, C, H, W) or (H, W)
         :return: 2-tuple (H, W)
         """
         # Make size argument into (heigth, width)
-
-
-
-        msg = (
-            "input size must be either (N, C, H, W) or (H, W), but it "
-            f"has length {len(size)} (size={size})"
-        )
-        raise RuntimeError(msg)
+        # assert len(size) in [2, 4, 5]  # TODO commented out cuz it's weird
+        # We're only interested in the Y,X dimensions
+        size = size[-2:]
 
         if self.multiscale_decoder_retain_spatial_dims is True:
             # In this case, we can go much more deeper and so this is not required
@@ -845,24 +801,21 @@ class LadderVAE(nn.Module):
             dwnsc = self.overall_downscale_factor
 
         # Output smallest powers of 2 that are larger than current sizes
-        padded_size =
-
+        padded_size = [((s - 1) // dwnsc + 1) * dwnsc for s in size]
+        # TODO Needed for pad/crop odd sizes. Move to dataset?
         return padded_size
 
     def get_latent_spatial_size(self, level_idx: int):
-        """
-        level_idx: 0 is the bottommost layer, the highest resolution one.
-        """
+        """Level_idx: 0 is the bottommost layer, the highest resolution one."""
         actual_downsampling = level_idx + 1
         dwnsc = 2**actual_downsampling
-        sz = self.get_padded_size(self.
+        sz = self.get_padded_size(self.image_size)
         h = sz[0] // dwnsc
         w = sz[1] // dwnsc
         assert h == w
         return h
 
     def get_top_prior_param_shape(self, n_imgs: int = 1):
-        # TODO num channels depends on random variable we're using
 
         # Compute the total downscaling performed in the Encoder
         if self.multiscale_decoder_retain_spatial_dims is False:
@@ -872,26 +825,12 @@ class LadderVAE(nn.Module):
         actual_downsampling = self.n_layers + 1 - self._multiscale_count
         dwnsc = 2**actual_downsampling
 
-
-
-
-
-
+        h = self.image_size[-2] // dwnsc
+        w = self.image_size[-1] // dwnsc
+        mu_logvar = self.z_dims[-1] * 2  # mu and logvar
+        top_layer_shape = (n_imgs, mu_logvar, h, w)
+        # TODO refactor!
+        if self._model_3D_depth > 1 and self._decoder_mode_3D is True:
+            # TODO check if model_3D_depth is needed ?
+            top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
         return top_layer_shape
-
-    def get_other_channel(self, ch1, input):
-        assert self.data_std["target"].squeeze().shape == (2,)
-        assert self.data_mean["target"].squeeze().shape == (2,)
-        assert self.target_ch == 2
-        ch1_un = (
-            ch1[:, :1] * self.data_std["target"][:, :1]
-            + self.data_mean["target"][:, :1]
-        )
-        input_un = input * self.data_std["input"] + self.data_mean["input"]
-        ch2_un = self._tethered_ch2_scalar * (
-            input_un - ch1_un * self._tethered_ch1_scalar
-        )
-        ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
-            :, -1:
-        ]
-        return ch2
```