careamics-0.1.0rc6-py3-none-any.whl → careamics-0.1.0rc7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/lvae.py
@@ -0,0 +1,985 @@
"""
Ladder VAE (LVAE) Model

The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
"""

from typing import Dict, Iterable, List, Tuple, Union

import ml_collections
import numpy as np
import torch
import torch.nn as nn

from .layers import (
    BottomUpDeterministicResBlock,
    BottomUpLayer,
    TopDownDeterministicResBlock,
    TopDownLayer,
)
from .likelihoods import GaussianLikelihood, NoiseModelLikelihood
from .noise_models import get_noise_model
from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor

class LadderVAE(nn.Module):

    def __init__(
        self,
        data_mean: Union[np.ndarray, Dict[str, torch.Tensor]],
        data_std: Union[np.ndarray, Dict[str, torch.Tensor]],
        config: ml_collections.ConfigDict,
        use_uncond_mode_at: Iterable[int] = [],
        target_ch: int = 2,
    ):
        """
        Constructor.

        Parameters
        ----------
        data_mean: Union[np.ndarray, Dict[str, torch.Tensor]]
            The mean of the data used for normalization.
        data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
            The standard deviation of the data used for normalization.
        config: ml_collections.ConfigDict
            The configuration object of the model.
        use_uncond_mode_at: Iterable[int], optional
            A sequence of indexes associated to the layers in which sampling is disabled
            and the mode (mean value) is used instead. Default is `[]`.
        target_ch: int, optional
            The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
            Default is `2`.
        """
        super().__init__()

        # -------------------------------------------------------
        # Customizable attributes
        self.image_size = config.data.image_size
        self._multiscale_count = config.data.multiscale_lowres_count
        self.z_dims = config.model.z_dims
        self.encoder_n_filters = config.model.n_filters
        self.decoder_n_filters = config.model.n_filters
        self.encoder_dropout = config.model.dropout
        self.decoder_dropout = config.model.dropout
        self.nonlin = config.model.nonlin
        self.predict_logvar = config.model.predict_logvar
        self.enable_noise_model = config.model.enable_noise_model
        self.noise_model_ch1_fpath = config.model.noise_model_ch1_fpath
        self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath
        self.analytical_kl = config.model.analytical_kl
        # -------------------------------------------------------

        # -------------------------------------------------------
        # Model attributes -> Hardcoded
        self.model_type = ModelType.LadderVae
        self.encoder_blocks_per_layer = 1
        self.decoder_blocks_per_layer = 1
        self.bottomup_batchnorm = True
        self.topdown_batchnorm = True
        self.topdown_conv2d_bias = True
        self.gated = True
        self.encoder_res_block_kernel = 3
        self.decoder_res_block_kernel = 3
        self.encoder_res_block_skip_padding = False
        self.decoder_res_block_skip_padding = False
        self.merge_type = "residual"
        self.no_initial_downscaling = True
        self.skip_bottomk_buvalues = 0
        self.non_stochastic_version = False
        self.stochastic_skip = True
        self.learn_top_prior = True
        self.res_block_type = "bacdbacd"
        self.mode_pred = False
        self.logvar_lowerbound = -5
        self._var_clip_max = 20
        self._stochastic_use_naive_exponential = False
        self._enable_topdown_normalize_factor = True

        # Noise model attributes -> Hardcoded
        self.noise_model_type = "gmm"
        self.denoise_channel = (
            "input"  # 4 values for denoise_channel {'Ch1', 'Ch2', 'input', 'all'}
        )
        self.noise_model_learnable = False

        # Attributes that handle LC -> Hardcoded
        self.enable_multiscale = (
            self._multiscale_count is not None and self._multiscale_count > 1
        )
        self.multiscale_retain_spatial_dims = True
        self.multiscale_lowres_separate_branch = False
        self.multiscale_decoder_retain_spatial_dims = (
            self.multiscale_retain_spatial_dims and self.enable_multiscale
        )

        # Derived attributes
        self.n_layers = len(self.z_dims)
        self.encoder_no_padding_mode = (
            self.encoder_res_block_skip_padding is True
            and self.encoder_res_block_kernel > 1
        )
        self.decoder_no_padding_mode = (
            self.decoder_res_block_skip_padding is True
            and self.decoder_res_block_kernel > 1
        )

        # Others...
        self._tethered_to_input = False
        self._tethered_ch1_scalar = self._tethered_ch2_scalar = None
        if self._tethered_to_input:
            target_ch = 1
            requires_grad = False
            self._tethered_ch1_scalar = nn.Parameter(
                torch.ones(1) * 0.5, requires_grad=requires_grad
            )
            self._tethered_ch2_scalar = nn.Parameter(
                torch.ones(1) * 2.0, requires_grad=requires_grad
            )
        # -------------------------------------------------------

        # -------------------------------------------------------
        # Data attributes
        self.color_ch = 1
        self.img_shape = (self.image_size, self.image_size)
        self.normalized_input = True
        # -------------------------------------------------------

        # -------------------------------------------------------
        # Loss attributes
        self._restricted_kl = False  # HC
        # enabling reconstruction loss on mixed input
        self.mixed_rec_w = 0
        self.nbr_consistency_w = 0

        # Setting the loss_type
        self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
        # -------------------------------------------------------

        # -------------------------------------------------------
        # # Training attributes
        # # can be used to tile the validation predictions
        # self._val_idx_manager = val_idx_manager
        # self._val_frame_creator = None
        # # initialize the learning rate scheduler params.
        # self.lr_scheduler_monitor = self.lr_scheduler_mode = None
        # self._init_lr_scheduler_params(config)
        # self._global_step = 0
        # -------------------------------------------------------

        # -------------------------------------------------------
        # Attributes from constructor arguments
        self.target_ch = target_ch
        self.use_uncond_mode_at = use_uncond_mode_at

        # Data mean and std used for normalization
        if isinstance(data_mean, np.ndarray):
            self.data_mean = torch.Tensor(data_mean)
            self.data_std = torch.Tensor(data_std)
        elif isinstance(data_mean, dict):
            for k in data_mean.keys():
                data_mean[k] = (
                    torch.Tensor(data_mean[k])
                    if not isinstance(data_mean[k], dict)
                    else data_mean[k]
                )
                data_std[k] = (
                    torch.Tensor(data_std[k])
                    if not isinstance(data_std[k], dict)
                    else data_std[k]
                )
            self.data_mean = data_mean
            self.data_std = data_std
        else:
            raise NotImplementedError(
                "data_mean and data_std must be either a numpy array or a dictionary"
            )

        assert self.data_std is not None
        assert self.data_mean is not None

        # Initialize the Noise Model
        self.likelihood_gm = self.likelihood_NM = None
        self.noiseModel = get_noise_model(
            enable_noise_model=self.enable_noise_model,
            model_type=self.model_type,
            noise_model_type=self.noise_model_type,
            noise_model_ch1_fpath=self.noise_model_ch1_fpath,
            noise_model_ch2_fpath=self.noise_model_ch2_fpath,
            noise_model_learnable=self.noise_model_learnable,
        )

        if self.noiseModel is None:
            self.likelihood_form = "gaussian"
        else:
            self.likelihood_form = "noise_model"

        # Calculate the downsampling happening in the network
        self.downsample = [1] * self.n_layers
        self.overall_downscale_factor = np.power(2, sum(self.downsample))
        if not self.no_initial_downscaling:  # by default do another downscaling
            self.overall_downscale_factor *= 2

        assert max(self.downsample) <= self.encoder_blocks_per_layer
        assert len(self.downsample) == self.n_layers
        # -------------------------------------------------------

        # -------------------------------------------------------
        ### CREATE MODEL BLOCKS
        # First bottom-up layer: change num channels + downsample by factor 2
        # unless we want to prevent this
        stride = 1 if self.no_initial_downscaling else 2
        self.first_bottom_up = self.create_first_bottom_up(stride)

        # Input Branches for Lateral Contextualization
        self.lowres_first_bottom_ups = None
        self._init_multires()

        # Other bottom-up layers
        self.bottom_up_layers = self.create_bottom_up_layers(
            self.multiscale_lowres_separate_branch
        )

        # Top-down layers
        self.top_down_layers = self.create_top_down_layers()
        self.final_top_down = self.create_final_topdown_layer(
            not self.no_initial_downscaling
        )

        # Likelihood module
        self.likelihood = self.create_likelihood_module()

        # Output layer --> Project to target_ch many channels
        logvar_ch_needed = self.predict_logvar is not None
        self.output_layer = self.parameter_net = nn.Conv2d(
            self.decoder_n_filters,
            self.target_ch * (1 + logvar_ch_needed),
            kernel_size=3,
            padding=1,
            bias=self.topdown_conv2d_bias,
        )

        # # gradient norms. updated while training. this is also logged.
        # self.grad_norm_bottom_up = 0.0
        # self.grad_norm_top_down = 0.0
        # PSNR computation on validation.
        # self.label1_psnr = RunningPSNR()
        # self.label2_psnr = RunningPSNR()

        # msg = f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
        # msg += f' TargetCh: {self.target_ch}'
        # print(msg)

    ### SET OF METHODS TO CREATE MODEL BLOCKS
    def create_first_bottom_up(
        self,
        init_stride: int,
        num_res_blocks: int = 1,
    ) -> nn.Sequential:
        """
        This method creates the first bottom-up block of the Encoder.
        Its role is to perform a first image compression step.
        It is composed of a sequence of nn.Conv2d + non-linearity +
        BottomUpDeterministicResBlock (1 or more, default is 1).

        Parameters
        ----------
        init_stride: int
            The stride used by the initial Conv2d block.
        num_res_blocks: int, optional
            The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
        """
        nonlin = self.get_nonlin()
        modules = [
            nn.Conv2d(
                in_channels=self.color_ch,
                out_channels=self.encoder_n_filters,
                kernel_size=self.encoder_res_block_kernel,
                padding=(
                    0
                    if self.encoder_res_block_skip_padding
                    else self.encoder_res_block_kernel // 2
                ),
                stride=init_stride,
            ),
            nonlin(),
        ]

        for _ in range(num_res_blocks):
            modules.append(
                BottomUpDeterministicResBlock(
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
                    nonlin=nonlin,
                    downsample=False,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
                    skip_padding=self.encoder_res_block_skip_padding,
                    res_block_kernel=self.encoder_res_block_kernel,
                )
            )

        return nn.Sequential(*modules)

    def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
        """
        This method creates the stack of bottom-up layers of the Encoder
        that are used to generate the so-called `bu_values`.

        NOTE:
        If `self._multiscale_count < self.n_layers`, then LC is done only in the first
        `self._multiscale_count` bottom-up layers (starting from the bottom).

        Parameters
        ----------
        lowres_separate_branch: bool
            Whether the residual block(s) used for encoding the low-res input are shared (`False`) or
            not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
        """
        multiscale_lowres_size_factor = 1
        nonlin = self.get_nonlin()

        bottom_up_layers = nn.ModuleList([])
        for i in range(self.n_layers):
            # Whether this is the top layer
            is_top = i == self.n_layers - 1

            # LC is applied only to the first (_multiscale_count - 1) bottom-up layers
            layer_enable_multiscale = (
                self.enable_multiscale and self._multiscale_count > i + 1
            )

            # This is the factor by which the low-resolution tensor is larger
            # N.B. Only used if layer_enable_multiscale == True, so we update it only in that case
            multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)

            output_expected_shape = (
                (self.img_shape[0] // 2 ** (i + 1), self.img_shape[1] // 2 ** (i + 1))
                if self._multiscale_count > 1
                else None
            )

            # Add bottom-up deterministic layer at level i.
            # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them.
            bottom_up_layers.append(
                BottomUpLayer(
                    n_res_blocks=self.encoder_blocks_per_layer,
                    n_filters=self.encoder_n_filters,
                    downsampling_steps=self.downsample[i],
                    nonlin=nonlin,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.encoder_res_block_kernel,
                    res_block_skip_padding=self.encoder_res_block_skip_padding,
                    gated=self.gated,
                    lowres_separate_branch=lowres_separate_branch,
                    enable_multiscale=self.enable_multiscale,  # shouldn't the arg be `layer_enable_multiscale` here?
                    multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
                    multiscale_lowres_size_factor=multiscale_lowres_size_factor,
                    decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
                    output_expected_shape=output_expected_shape,
                )
            )

        return bottom_up_layers

    def create_top_down_layers(self) -> nn.ModuleList:
        """
        This method creates the stack of top-down layers of the Decoder.
        In these layers the `bu_values` from the Encoder are merged with the `p_params` from the previous layer
        of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
        with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
        compute the `p_params` for the layer below.

        NOTE 1:
        The algorithm for generative inference approximately works as follows:
            - p_params = output of top-down layer above
            - bu = inferred bottom-up value at this layer
            - q_params = merge(bu, p_params)
            - z = stochastic_layer(q_params)
            - (optional) get and merge skip connection from prev top-down layer
            - top-down deterministic ResNet

        NOTE 2:
        When doing unconditional generation, bu_value is not available. Hence the
        merge layer is not used, and z is sampled directly from p_params.

        Parameters
        ----------
        """
        top_down_layers = nn.ModuleList([])
        nonlin = self.get_nonlin()
        # NOTE: top-down layers are created starting from the bottom-most
        for i in range(self.n_layers):
            # Check if this is the top layer
            is_top = i == self.n_layers - 1

            if self._enable_topdown_normalize_factor:
                normalize_latent_factor = (
                    1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
                )
            else:
                normalize_latent_factor = 1.0

            top_down_layers.append(
                TopDownLayer(
                    z_dim=self.z_dims[i],
                    n_res_blocks=self.decoder_blocks_per_layer,
                    n_filters=self.decoder_n_filters,
                    is_top_layer=is_top,
                    downsampling_steps=self.downsample[i],
                    nonlin=nonlin,
                    merge_type=self.merge_type,
                    batchnorm=self.topdown_batchnorm,
                    dropout=self.decoder_dropout,
                    stochastic_skip=self.stochastic_skip,
                    learn_top_prior=self.learn_top_prior,
                    top_prior_param_shape=self.get_top_prior_param_shape(),
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.decoder_res_block_kernel,
                    res_block_skip_padding=self.decoder_res_block_skip_padding,
                    gated=self.gated,
                    analytical_kl=self.analytical_kl,
                    restricted_kl=self._restricted_kl,
                    vanilla_latent_hw=self.get_latent_spatial_size(i),
                    # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so the merging operation does not work natively.
                    bottomup_no_padding_mode=self.encoder_no_padding_mode,
                    topdown_no_padding_mode=self.decoder_no_padding_mode,
                    retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
                    non_stochastic_version=self.non_stochastic_version,
                    input_image_shape=self.img_shape,
                    normalize_latent_factor=normalize_latent_factor,
                    conv2d_bias=self.topdown_conv2d_bias,
                    stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
                )
            )
        return top_down_layers

    def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
        """
        This method creates the final top-down layer of the Decoder.

        Parameters
        ----------
        upsample: bool
            Whether to upsample the input of the final top-down layer
            by bilinear interpolation with `scale_factor=2`.
        """
        # Final top-down layer
        modules = list()

        if upsample:
            modules.append(Interpolate(scale=2))

        for i in range(self.decoder_blocks_per_layer):
            modules.append(
                TopDownDeterministicResBlock(
                    c_in=self.decoder_n_filters,
                    c_out=self.decoder_n_filters,
                    nonlin=self.get_nonlin(),
                    batchnorm=self.topdown_batchnorm,
                    dropout=self.decoder_dropout,
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.decoder_res_block_kernel,
                    skip_padding=self.decoder_res_block_skip_padding,
                    gated=self.gated,
                    conv2d_bias=self.topdown_conv2d_bias,
                )
            )
        return nn.Sequential(*modules)

    def create_likelihood_module(self):
        """
        This method defines the likelihood module for the current LVAE model.
        The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
        """
        self.likelihood_gm = GaussianLikelihood(
            self.decoder_n_filters,
            self.target_ch,
            predict_logvar=self.predict_logvar,
            logvar_lowerbound=self.logvar_lowerbound,
            conv2d_bias=self.topdown_conv2d_bias,
        )

        self.likelihood_NM = None
        if self.enable_noise_model:
            self.likelihood_NM = NoiseModelLikelihood(
                self.decoder_n_filters,
                self.target_ch,
                self.data_mean,
                self.data_std,
                self.noiseModel,
            )
        if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
            return self.likelihood_gm

        return self.likelihood_NM

    def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList:
        """
        This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
        in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
        in the primary flow of the Encoder, namely to compress the lateral input image to a degree that is compatible with the
        one of the primary flow.

        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity + BottomUpDeterministicResBlock.
        Note that the `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc.). Moreover, it does not perform downsampling.

        NOTE 2: The `_multiscale_count` attribute defines the total number of inputs to the bottom-up pass.
        In other words, if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
        """
        stride = 1 if self.no_initial_downscaling else 2
        nonlin = self.get_nonlin()
        if self._multiscale_count is None:
            self._multiscale_count = 1

        msg = "Multiscale count({}) should not exceed the number of bottom up layers ({}) by more than 1"
        msg = msg.format(self._multiscale_count, self.n_layers)
        assert (
            self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
        ), msg

        msg = (
            "if multiscale is enabled, then we are just working with monochrome images."
        )
        assert self._multiscale_count == 1 or self.color_ch == 1, msg

        lowres_first_bottom_ups = []
        for _ in range(1, self._multiscale_count):
            first_bottom_up = nn.Sequential(
                nn.Conv2d(
                    in_channels=self.color_ch,
                    out_channels=self.encoder_n_filters,
                    kernel_size=5,
                    padding=2,
                    stride=stride,
                ),
                nonlin(),
                BottomUpDeterministicResBlock(
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
                    nonlin=nonlin,
                    downsample=False,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
                    skip_padding=self.encoder_res_block_skip_padding,
                ),
            )
            lowres_first_bottom_ups.append(first_bottom_up)

        self.lowres_first_bottom_ups = (
            nn.ModuleList(lowres_first_bottom_ups)
            if len(lowres_first_bottom_ups)
            else None
        )

    ### SET OF FORWARD-LIKE METHODS
    def bottomup_pass(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """
        Wrapper of _bottomup_pass().
        """
        return self._bottomup_pass(
            inp,
            self.first_bottom_up,
            self.lowres_first_bottom_ups,
            self.bottom_up_layers,
        )

    def _bottomup_pass(
        self,
        inp: torch.Tensor,
        first_bottom_up: nn.Sequential,
        lowres_first_bottom_ups: nn.ModuleList,
        bottom_up_layers: nn.ModuleList,
    ) -> List[torch.Tensor]:
        """
        This method defines the forward pass through the LVAE Encoder, the so-called
        Bottom-Up pass.

        Parameters
        ----------
        inp: torch.Tensor
            The input tensor to the bottom-up pass of shape (B, 1+n_LC, H, W), where n_LC
            is the number of lateral low-res inputs used in the LC approach.
            In particular, the first channel corresponds to the input patch, while the
            remaining ones are associated to the lateral low-res inputs.
        first_bottom_up: nn.Sequential
            The module defining the first bottom-up layer of the Encoder.
        lowres_first_bottom_ups: nn.ModuleList
            The list of modules defining Lateral Contextualization.
        bottom_up_layers: nn.ModuleList
            The list of modules defining the stack of bottom-up layers of the Encoder.
        """
        if self._multiscale_count > 1:
            x = first_bottom_up(inp[:, :1])
        else:
            x = first_bottom_up(inp)

        # Loop from bottom to top layer, store all deterministic nodes we
        # need for the top-down pass in bu_values list
        bu_values = []
        for i in range(self.n_layers):
            lowres_x = None
            if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
                lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])

            x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
            bu_values.append(bu_value)

        return bu_values

    def topdown_pass(
        self,
        bu_values: torch.Tensor = None,
        n_img_prior: torch.Tensor = None,
        mode_layers: Iterable[int] = None,
        constant_layers: Iterable[int] = None,
        forced_latent: List[torch.Tensor] = None,
        top_down_layers: nn.ModuleList = None,
        final_top_down_layer: nn.Sequential = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        This method defines the forward pass through the LVAE Decoder, the so-called
        Top-Down pass.

        Parameters
        ----------
        bu_values: torch.Tensor, optional
            Output of the bottom-up pass. It will have values from multiple layers of the ladder.
        n_img_prior: optional
            When `bu_values` is `None`, `n_img_prior` indicates the number of images to generate
            from the prior (so the bottom-up pass is not used at all here).
        mode_layers: Iterable[int], optional
            A sequence of indexes associated to the layers in which sampling is disabled and
            the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
        constant_layers: Iterable[int], optional
            A sequence of indexes associated to the layers in which a single instance's z is
            copied over the entire batch (bottom-up path is not used, so only prior is used here).
            Set to `None` to avoid this behaviour.
        forced_latent: List[torch.Tensor], optional
            A list of tensors that are used as fixed latent variables (hence, sampling doesn't take
            place in this case).
        top_down_layers: nn.ModuleList, optional
            A list of top-down layers to use in the top-down pass. If `None`, the method uses the
            default layers defined in the constructor.
        final_top_down_layer: nn.Sequential, optional
            The last top-down layer of the top-down pass. If `None`, the method uses the default
            layers defined in the constructor.
        """
        if top_down_layers is None:
            top_down_layers = self.top_down_layers
        if final_top_down_layer is None:
            final_top_down_layer = self.final_top_down

        # Default: no layer is sampled from the distribution's mode
        if mode_layers is None:
            mode_layers = []
        if constant_layers is None:
            constant_layers = []
        prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0

        # If the bottom-up inference values are not given, don't do
        # inference, sample from prior instead
        inference_mode = bu_values is not None

        # Check consistency of arguments
        if inference_mode != (n_img_prior is None):
            msg = (
                "Number of images for top-down generation has to be given "
                "if and only if we're not doing inference"
            )
            raise RuntimeError(msg)
        if (
            inference_mode
            and prior_experiment
            and (self.non_stochastic_version is False)
        ):
            msg = (
                "Prior experiments (e.g. sampling from mode) are not"
                " compatible with inference mode"
            )
            raise RuntimeError(msg)

        # Sampled latent variables at each layer
        z = [None] * self.n_layers

        # KL divergence of each layer
        kl = [None] * self.n_layers
        # KL divergence restricted, only for the LC enabled setup denoiSplit.
        kl_restricted = [None] * self.n_layers

        # mean from which z is sampled.
        q_mu = [None] * self.n_layers
        # log(var) from which z is sampled.
        q_lv = [None] * self.n_layers

        # Spatial map of KL divergence for each layer
        kl_spatial = [None] * self.n_layers

        debug_qvar_max = [None] * self.n_layers

        kl_channelwise = [None] * self.n_layers

        if forced_latent is None:
            forced_latent = [None] * self.n_layers

        # log p(z) where z is the sample in the topdown pass
        # logprob_p = 0.

        # Top-down inference/generation loop
        out = out_pre_residual = None
        for i in reversed(range(self.n_layers)):

            # If available, get deterministic node from bottom-up inference
            try:
                bu_value = bu_values[i]
            except TypeError:
                bu_value = None

            # Whether the current layer should be sampled from the mode
            use_mode = i in mode_layers
            constant_out = i in constant_layers
            use_uncond_mode = i in self.use_uncond_mode_at

            # Input for skip connection
            skip_input = out  # TODO or n? or both?

            # Full top-down layer, including sampling and deterministic part
            out, out_pre_residual, aux = top_down_layers[i](
                input_=out,
                skip_connection_input=skip_input,
                inference_mode=inference_mode,
                bu_value=bu_value,
                n_img_prior=n_img_prior,
                use_mode=use_mode,
                force_constant_output=constant_out,
                forced_latent=forced_latent[i],
                mode_pred=self.mode_pred,
                use_uncond_mode=use_uncond_mode,
                var_clip_max=self._var_clip_max,
            )

            # Save useful variables
            z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
            kl[i] = aux["kl_samplewise"]  # (batch, )
            kl_restricted[i] = aux["kl_samplewise_restricted"]
            kl_spatial[i] = aux["kl_spatial"]  # (batch, h, w)
            q_mu[i] = aux["q_mu"]
            q_lv[i] = aux["q_lv"]

            kl_channelwise[i] = aux["kl_channelwise"]
            debug_qvar_max[i] = aux["qvar_max"]
            # if self.mode_pred is False:
            #     logprob_p += aux['logprob_p'].mean()  # mean over batch
            # else:
            #     logprob_p = None

        # Final top-down layer
        out = final_top_down_layer(out)

        # Store useful variables in a dict to return them
        data = {
            "z": z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
            "kl": kl,  # list of tensors with shape (batch, )
            "kl_restricted": kl_restricted,  # list of tensors with shape (batch, )
            "kl_spatial": kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
            "kl_channelwise": kl_channelwise,  # list of tensors with shape (batch, ch[i])
            # 'logprob_p': logprob_p,  # scalar, mean over batch
            "q_mu": q_mu,
            "q_lv": q_lv,
            "debug_qvar_max": debug_qvar_max,
        }
        return out, data

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Parameters
        ----------
        x: torch.Tensor
            The input tensor of shape (B, C, H, W).
        """
        img_size = x.size()[2:]

        # Pad input to size equal to the closest power of 2
        x_pad = self.pad_input(x)

        # Bottom-up inference: return list of length n_layers (bottom to top)
        bu_values = self.bottomup_pass(x_pad)
        for i in range(0, self.skip_bottomk_buvalues):
            bu_values[i] = None

        mode_layers = range(self.n_layers) if self.non_stochastic_version else None

        # Top-down inference/generation
        out, td_data = self.topdown_pass(bu_values, mode_layers=mode_layers)

        if out.shape[-1] > img_size[-1]:
            # Restore original image size
            out = crop_img_tensor(out, img_size)

        out = self.output_layer(out)
        if self._tethered_to_input:
            assert out.shape[1] == 1
            ch2 = self.get_other_channel(out, x_pad)
            out = torch.cat([out, ch2], dim=1)

        return out, td_data

    ### SET OF UTILS METHODS
    # def sample_prior(
    #     self,
    #     n_imgs,
    #     mode_layers=None,
    #     constant_layers=None
    # ):

    #     # Generate from prior
    #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
    #     out = crop_img_tensor(out, self.img_shape)

    #     # Log likelihood and other info (per data point)
    #     _, likelihood_data = self.likelihood(out, None)

    #     return likelihood_data['sample']

    # ### ???
    # def sample_from_q(self, x, masks=None):
    #     """
    #     This method performs the bottomup_pass() and samples from the
    #     obtained distribution.
    #     """
    #     img_size = x.size()[2:]

    #     # Pad input to make everything easier with conv strides
    #     x_pad = self.pad_input(x)

    #     # Bottom-up inference: return list of length n_layers (bottom to top)
    #     bu_values = self.bottomup_pass(x_pad)
    #     return self._sample_from_q(bu_values, masks=masks)
    # ### ???

    # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
    #     if top_down_layers is None:
    #         top_down_layers = self.top_down_layers
    #     if final_top_down_layer is None:
    #         final_top_down_layer = self.final_top_down
    #     if masks is None:
    #         masks = [None] * len(bu_values)

    #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
    #     assert self.n_layers == 1, msg
    #     samples = []
    #     for i in reversed(range(self.n_layers)):
    #         bu_value = bu_values[i]

    #         # Note that the first argument can be set to None since we are just dealing with one level
    #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
    #         samples.append(sample)

    #     return samples

    # def reset_for_different_output_size(self, output_size):
    #     for i in range(self.n_layers):
    #         sz = output_size // 2**(1 + i)
    #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
    #         self.top_down_layers[i].latent_shape = (output_size, output_size)

    def pad_input(self, x):
        """
        Pads input x so that its sizes are powers of 2
        :param x:
        :return: Padded tensor
        """
        size = self.get_padded_size(x.size())
        x = pad_img_tensor(x, size)
        return x

    ### SET OF GETTERS
    def get_nonlin(self):
        nonlin = {
            "relu": nn.ReLU,
            "leakyrelu": nn.LeakyReLU,
            "elu": nn.ELU,
            "selu": nn.SELU,
        }
        return nonlin[self.nonlin]

    def get_padded_size(self, size):
        """
        Returns the smallest size (H, W) of the image with actual size given
        as input, such that H and W are powers of 2.
        :param size: input size, tuple either (N, C, H, W) or (H, W)
        :return: 2-tuple (H, W)
        """
        # Make size argument into (height, width)
        if len(size) == 4:
            size = size[2:]
        if len(size) != 2:
            msg = (
                "input size must be either (N, C, H, W) or (H, W), but it "
                f"has length {len(size)} (size={size})"
            )
            raise RuntimeError(msg)

        if self.multiscale_decoder_retain_spatial_dims is True:
            # In this case, we can go much deeper and so this is not required
            # (in the way it is ;) ). More work would be needed if this was to be correctly implemented.
            return list(size)

        # Overall downscale factor from input to top layer (power of 2)
        dwnsc = self.overall_downscale_factor

        # Output smallest powers of 2 that are larger than current sizes
        padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)

        return padded_size

    def get_latent_spatial_size(self, level_idx: int):
        """
        level_idx: 0 is the bottommost layer, the highest resolution one.
        """
        actual_downsampling = level_idx + 1
        dwnsc = 2**actual_downsampling
        sz = self.get_padded_size(self.img_shape)
        h = sz[0] // dwnsc
        w = sz[1] // dwnsc
        assert h == w
        return h

    def get_top_prior_param_shape(self, n_imgs: int = 1):
        # TODO num channels depends on random variable we're using

        # Compute the total downscaling performed in the Encoder
        if self.multiscale_decoder_retain_spatial_dims is False:
            dwnsc = self.overall_downscale_factor
        else:
            # LC allows the encoder latents to keep the same (H, W) size at different levels
            actual_downsampling = self.n_layers + 1 - self._multiscale_count
            dwnsc = 2**actual_downsampling

        sz = self.get_padded_size(self.img_shape)
        h = sz[0] // dwnsc
        w = sz[1] // dwnsc
        c = self.z_dims[-1] * 2  # mu and logvar
        top_layer_shape = (n_imgs, c, h, w)
        return top_layer_shape

    def get_other_channel(self, ch1, input):
        assert self.data_std["target"].squeeze().shape == (2,)
        assert self.data_mean["target"].squeeze().shape == (2,)
        assert self.target_ch == 2
        ch1_un = (
            ch1[:, :1] * self.data_std["target"][:, :1]
            + self.data_mean["target"][:, :1]
        )
        input_un = input * self.data_std["input"] + self.data_mean["input"]
        ch2_un = self._tethered_ch2_scalar * (
            input_un - ch1_un * self._tethered_ch1_scalar
        )
        ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
            :, -1:
        ]
        return ch2
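
For orientation, here is a minimal usage sketch of the new LadderVAE module added in this release. It is illustrative only and not code from the package: the configuration fields mirror exactly what `__init__` reads above (`config.data.*`, `config.model.*`, `config.loss`), but the concrete values are arbitrary, and it assumes that `get_noise_model` returns `None` when `enable_noise_model` is `False`, so no GMM noise-model files are needed.

# Hypothetical usage sketch for careamics/models/lvae/lvae.py (0.1.0rc7); values are illustrative.
import ml_collections
import numpy as np
import torch

from careamics.models.lvae.lvae import LadderVAE

config = ml_collections.ConfigDict()
config.data = dict(image_size=64, multiscale_lowres_count=None)  # no lateral contextualization
config.model = dict(
    z_dims=[128, 128, 128, 128],  # one entry per ladder level -> n_layers = 4
    n_filters=64,
    dropout=0.1,
    nonlin="elu",  # must be one of: "relu", "leakyrelu", "elu", "selu"
    predict_logvar=None,
    enable_noise_model=False,  # assumption: the noise-model file paths may then stay None
    noise_model_ch1_fpath=None,
    noise_model_ch2_fpath=None,
    analytical_kl=False,
)
config.loss = dict()  # loss_type falls back to LossType.DenoiSplitMuSplit

# Normalization statistics as numpy arrays (a dict of tensors is also accepted).
data_mean = np.array([0.0], dtype=np.float32)
data_std = np.array([1.0], dtype=np.float32)

model = LadderVAE(data_mean=data_mean, data_std=data_std, config=config, target_ch=2)

x = torch.randn(2, 1, 64, 64)  # (B, C=1, H, W), with H = W = config.data.image_size
out, td_data = model(x)
# out: expected shape (2, target_ch, 64, 64) when predict_logvar is None;
# td_data holds per-layer z, kl, kl_restricted, kl_spatial, kl_channelwise, q_mu, q_lv, debug_qvar_max.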