careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1998 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script containing the common basic blocks (nn.Module) reused by the LadderVAE architecture.
|
|
3
|
+
|
|
4
|
+
Hierarchy in the model blocks:
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from copy import deepcopy
|
|
9
|
+
from typing import Callable, Dict, Iterable, Literal, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
import torchvision.transforms.functional as F
|
|
14
|
+
from torch.distributions import kl_divergence
|
|
15
|
+
from torch.distributions.normal import Normal
|
|
16
|
+
|
|
17
|
+
from .utils import (
|
|
18
|
+
StableLogVar,
|
|
19
|
+
StableMean,
|
|
20
|
+
crop_img_tensor,
|
|
21
|
+
kl_normal_mc,
|
|
22
|
+
pad_img_tensor,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ResidualBlock(nn.Module):
    """
    Residual block with 2 convolutional layers.

    Some architectural notes:
        - The number of input, intermediate, and output channels is the same,
        - Padding is 'same' by default; with `skip_padding=True` no padding is
          applied and the skip connection is center-cropped in `forward`,
        - The 2 convolutional layers have the same groups,
        - No stride allowed,
        - Kernel sizes must be odd.

    The output is given by: `out = gate(f(x)) + x`.
    The presence of the gating mechanism is optional, and f(x) has different
    structures depending on the `block_type` argument.
    Specifically, `block_type` is a string specifying the block's structure, with:
        a = activation
        b = batch norm
        c = conv layer
        d = dropout.
    For example, "bacdbacd" defines a block with 2x[batchnorm, activation, conv, dropout].
    Supported values are "cabdcabd", "bacdbac" and "bacdbacd". Batch norm layers
    are only added when `batchnorm=True`. Dropout layers are only added when
    `dropout` is not `None`.
    """

    default_kernel_size = (3, 3)

    def __init__(
        self,
        channels: int,
        nonlin: Callable,
        kernel: Union[int, Iterable[int]] = None,
        groups: int = 1,
        batchnorm: bool = True,
        block_type: str = None,
        dropout: float = None,
        gated: bool = None,
        skip_padding: bool = False,
        conv2d_bias: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        channels: int
            The number of input and output channels (they are the same).
        nonlin: Callable
            The non-linearity class used in the block (e.g., `nn.ReLU`). It is
            instantiated with no arguments.
        kernel: Union[int, Iterable[int]], optional
            The kernel sizes of the two convolutions of the block. An int is
            used for both convolutions. A pair gives the (square) kernel size
            of the first and second convolution respectively. All sizes must
            be odd. Default is `None`, i.e. `default_kernel_size`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        block_type: str, optional
            A string specifying the block structure, check class docstring for more info.
            Default is `None`, which is rejected with a `ValueError`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to append a `GateLayer2d` to the block. Default is `None`.
        skip_padding: bool, optional
            Whether to skip padding in convolutions. Default is `False`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.

        Raises
        ------
        ValueError
            If `kernel` is an iterable whose length is not 2, or if `block_type`
            is not one of the supported structures.
        """
        super().__init__()

        # Normalise `kernel` to a list of two sizes: kernel[i] is the square
        # kernel size of the i-th convolution.
        if kernel is None:
            kernel = self.default_kernel_size
        elif isinstance(kernel, int):
            kernel = (kernel, kernel)
        elif len(kernel) != 2:
            raise ValueError("kernel has to be None, int, or an iterable of length 2")
        assert all(k % 2 == 1 for k in kernel), "kernel sizes have to be odd"
        kernel = list(kernel)
        self.skip_padding = skip_padding
        # For odd kernels, k // 2 padding keeps the spatial size unchanged.
        # Without padding every conv shrinks the input and `forward` crops the skip.
        pad = [0] * len(kernel) if self.skip_padding else [k // 2 for k in kernel]

        def make_conv(i: int) -> nn.Conv2d:
            # Channel-preserving convolution using the i-th kernel/padding.
            return nn.Conv2d(
                channels,
                channels,
                kernel[i],
                padding=pad[i],
                groups=groups,
                bias=conv2d_bias,
            )

        modules = []
        if block_type == "cabdcabd":
            for i in range(2):
                modules.append(make_conv(i))
                modules.append(nonlin())
                if batchnorm:
                    modules.append(nn.BatchNorm2d(channels))
                if dropout is not None:
                    modules.append(nn.Dropout2d(dropout))
        elif block_type == "bacdbac":
            for i in range(2):
                if batchnorm:
                    modules.append(nn.BatchNorm2d(channels))
                modules.append(nonlin())
                modules.append(make_conv(i))
                # Dropout only after the first convolution ("bacd" + "bac").
                if dropout is not None and i == 0:
                    modules.append(nn.Dropout2d(dropout))
        elif block_type == "bacdbacd":
            for i in range(2):
                if batchnorm:
                    modules.append(nn.BatchNorm2d(channels))
                modules.append(nonlin())
                modules.append(make_conv(i))
                # Guard like the other branches: nn.Dropout2d(None) raises.
                if dropout is not None:
                    modules.append(nn.Dropout2d(dropout))
        else:
            raise ValueError(f"unrecognized block type '{block_type}'")

        self.gated = gated
        if gated:
            modules.append(GateLayer2d(channels, 1, nonlin))

        self.block = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block and add the input as skip connection.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor with `channels` channels.

        Returns
        -------
        torch.Tensor
            `block(x) + x`. When the block output is spatially smaller (e.g.
            with `skip_padding=True`), `x` is center-cropped to the output's
            last two dimensions before the sum.
        """
        out = self.block(x)
        if out.shape != x.shape:
            return out + F.center_crop(x, out.shape[-2:])
        return out + x
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class ResidualGatedBlock(ResidualBlock):
    """
    `ResidualBlock` with the gating layer always enabled.

    All arguments are forwarded to `ResidualBlock`, with `gated=True` forced.
    Passing `gated` explicitly therefore raises a `TypeError`
    (multiple values for the same argument).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, gated=True, **kwargs)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class GateLayer2d(nn.Module):
    """
    Gating layer: a convolution doubles the number of channels, then one half
    (after `nonlin`) is multiplied by the sigmoid of the other half.

    Parameters
    ----------
    channels: int
        Number of input (and output) channels.
    kernel_size: int
        Odd kernel size of the convolution. Padding `kernel_size // 2` keeps the
        spatial size unchanged.
    nonlin: Callable, optional
        Non-linearity class applied to the feature half, instantiated with no
        arguments. Default is `nn.LeakyReLU`.
    """

    def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU):
        super().__init__()
        assert kernel_size % 2 == 1
        self.conv = nn.Conv2d(
            channels, 2 * channels, kernel_size, padding=kernel_size // 2
        )
        self.nonlin = nonlin()

    def forward(self, x):
        # First half of the channels: features; second half: gate logits.
        features, gate_logits = self.conv(x).chunk(2, dim=1)
        return self.nonlin(features) * gate_logits.sigmoid()  # TODO remove nonlin?
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class ResBlockWithResampling(nn.Module):
    """
    Residual block that optionally resamples its input by a factor 2.

    The block applies, in order:
        1. `pre_conv`: when `resample=True`, a 3x3 stride-2 convolution with
           1 pixel of zero-padding (mode "bottom-up", downsampling) or a 3x3
           stride-2 transposed convolution (mode "top-down", upsampling).
           Otherwise it is a 1x1 convolution mapping `c_in` to the inner
           channel count, or `None` when the two already match.
        2. `res`: a `ResidualBlock` working on the inner channels.
        3. `post_conv`: a 1x1 convolution mapping the inner channels to
           `c_out`, or `None` when they already match.

    The inner channel count is `c_out`, unless `min_inner_channels` is larger.
    """

    def __init__(
        self,
        mode: Literal["top-down", "bottom-up"],
        c_in: int,
        c_out: int,
        min_inner_channels: int = None,
        nonlin: Callable = nn.LeakyReLU,
        resample: bool = False,
        res_block_kernel: Union[int, Iterable[int]] = None,
        groups: int = 1,
        batchnorm: bool = True,
        res_block_type: str = None,
        dropout: float = None,
        gated: bool = None,
        skip_padding: bool = False,
        conv2d_bias: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        mode: Literal["top-down", "bottom-up"]
            Direction of resampling used when `resample=True`:
            "bottom-up" downsamples by a factor 2, "top-down" upsamples by a factor 2.
        c_in: int
            The number of input channels.
        c_out: int
            The number of output channels.
        min_inner_channels: int, optional
            Lower bound on the number of channels of the inner layers.
            Default is `None`, meaning that the number of inner channels is `c_out`.
        nonlin: Callable, optional
            The non-linearity class used in the block. Default is `nn.LeakyReLU`.
        resample: bool, optional
            Whether the first layer is a resampling (strided) convolution.
            If `False`, the first layer only maps the input to the inner
            channels through a 1x1 convolution (if needed). Default is `False`.
        res_block_kernel: Union[int, Iterable[int]], optional
            Kernel size(s) of the two convolutions in the residual block: an
            int for both, or one odd size per convolution. Default is `None`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        res_block_type: str, optional
            A string specifying the structure of the residual block.
            Check the `ResidualBlock` docstring for more information.
            Default is `None`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to use a gated layer in the residual block. Default is `None`.
        skip_padding: bool, optional
            Whether to skip padding in the residual block convolutions. Default is `False`.
        conv2d_bias: bool, optional
            Whether to use a bias term in convolutions. Default is `True`.
        """
        super().__init__()
        assert mode in ["top-down", "bottom-up"]

        floor = 0 if min_inner_channels is None else min_inner_channels
        inner_channels = max(c_out, floor)

        # 1. Input stage: resample and/or adapt the number of channels.
        if not resample:
            if c_in == inner_channels:
                self.pre_conv = None
            else:
                self.pre_conv = nn.Conv2d(
                    c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
                )
        elif mode == "bottom-up":
            self.pre_conv = nn.Conv2d(
                in_channels=c_in,
                out_channels=inner_channels,
                kernel_size=3,
                padding=1,
                stride=2,
                groups=groups,
                bias=conv2d_bias,
            )
        elif mode == "top-down":
            # output_padding=1 makes the output exactly twice the input size.
            self.pre_conv = nn.ConvTranspose2d(
                in_channels=c_in,
                out_channels=inner_channels,
                kernel_size=3,
                padding=1,
                stride=2,
                output_padding=1,
                groups=groups,
                bias=conv2d_bias,
            )

        # 2. Residual core operating on `inner_channels`.
        self.res = ResidualBlock(
            channels=inner_channels,
            nonlin=nonlin,
            kernel=res_block_kernel,
            groups=groups,
            batchnorm=batchnorm,
            dropout=dropout,
            gated=gated,
            block_type=res_block_type,
            skip_padding=skip_padding,
            conv2d_bias=conv2d_bias,
        )

        # 3. Output stage: map back to `c_out` channels if needed.
        self.post_conv = (
            nn.Conv2d(inner_channels, c_out, 1, groups=groups, bias=conv2d_bias)
            if inner_channels != c_out
            else None
        )

    def forward(self, x):
        # Stages that were not needed are stored as None and skipped.
        for stage in (self.pre_conv, self.res, self.post_conv):
            if stage is not None:
                x = stage(x)
        return x
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class TopDownDeterministicResBlock(ResBlockWithResampling):
    """
    `ResBlockWithResampling` in "top-down" mode.

    `upsample` is forwarded as `resample` and overrides any `resample` keyword
    given by the caller. All other arguments are forwarded unchanged.
    """

    def __init__(self, *args, upsample: bool = False, **kwargs):
        super().__init__("top-down", *args, **{**kwargs, "resample": upsample})
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class BottomUpDeterministicResBlock(ResBlockWithResampling):
    """
    `ResBlockWithResampling` in "bottom-up" mode.

    `downsample` is forwarded as `resample` and overrides any `resample` keyword
    given by the caller. All other arguments are forwarded unchanged.
    """

    def __init__(self, *args, downsample: bool = False, **kwargs):
        super().__init__("bottom-up", *args, **{**kwargs, "resample": downsample})
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class BottomUpLayer(nn.Module):
    """
    Bottom-up deterministic layer.

    It consists of a stack of `BottomUpDeterministicResBlock`s. The first
    `downsampling_steps` blocks downsample (`net_downsized`), and the remaining
    ones keep the spatial size (`net`). The outputs are the so-called
    `bu_values`, later used in the Decoder to update the generative distributions.

    NOTE: When Lateral Contextualization is enabled (`enable_multiscale=True`),
    the low-res lateral input is first fed through bottom-up deterministic
    block(s) without downsampling. The result is then merged with the output of
    the primary flow through a `MergeLowRes` layer. The low-res branch is either
    the very same "same-size" stack used by the primary flow (`self.net`) or a
    deep copy of it. This is controlled by `lowres_separate_branch`.
    """

    def __init__(
        self,
        n_res_blocks: int,
        n_filters: int,
        downsampling_steps: int = 0,
        nonlin: Callable = None,
        batchnorm: bool = True,
        dropout: float = None,
        res_block_type: str = None,
        res_block_kernel: int = None,
        res_block_skip_padding: bool = False,
        gated: bool = None,
        enable_multiscale: bool = False,
        multiscale_lowres_size_factor: int = None,
        lowres_separate_branch: bool = False,
        multiscale_retain_spatial_dims: bool = False,
        decoder_retain_spatial_dims: bool = False,
        output_expected_shape: Iterable[int] = None,
    ):
        """
        Constructor.

        Parameters
        ----------
        n_res_blocks: int
            Number of `BottomUpDeterministicResBlock` modules stacked in this layer.
        n_filters: int
            Number of channels used throughout the blocks of this layer.
        downsampling_steps: int, optional
            How many of the stacked blocks downsample (typically 1). Default is 0.
        nonlin: Callable, optional
            The non-linearity class used in the blocks. Default is `None`.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of the residual blocks.
            Check the `ResidualBlock` docstring for more information.
            Default is `None`.
        res_block_kernel: Union[int, Iterable[int]], optional
            Kernel size(s) of the convolutions in the residual blocks.
            Default is `None`.
        res_block_skip_padding: bool, optional
            Whether to skip padding in the residual block convolutions. Default is `False`.
        gated: bool, optional
            Whether to use gated layers. Default is `None`.
        enable_multiscale: bool, optional
            Whether to enable multiscale (Lateral Contextualization). Default is `False`.
        multiscale_lowres_size_factor: int, optional
            Ratio between the size of the primary-flow tensor and the size
            expected by the Decoder. It is used to crop the output when
            `output_expected_shape` is not given. Default is `None`.
        lowres_separate_branch: bool, optional
            Whether the blocks encoding the low-res input are a separate copy
            (`True`) of, or shared (`False`) with, the primary-flow "same-size"
            blocks. Default is `False`.
        multiscale_retain_spatial_dims: bool, optional
            Whether the merged tensor keeps its spatial size, so that the value
            passed to the top-down pass must be cropped. Default is `False`.
        decoder_retain_spatial_dims: bool, optional
            If `True`, the merged tensor is passed to the top-down pass without
            cropping. Default is `False`.
        output_expected_shape: Iterable[int], optional
            Spatial shape of the value passed to the top-down pass. Only allowed
            if `enable_multiscale` is `True`. Default is `None`.
        """
        super().__init__()

        # Lateral Contextualization (LC) settings.
        self.enable_multiscale = enable_multiscale
        self.lowres_separate_branch = lowres_separate_branch
        self.multiscale_retain_spatial_dims = multiscale_retain_spatial_dims
        self.multiscale_lowres_size_factor = multiscale_lowres_size_factor
        self.decoder_retain_spatial_dims = decoder_retain_spatial_dims
        self.output_expected_shape = output_expected_shape
        assert self.output_expected_shape is None or self.enable_multiscale is True

        # Block k downsamples iff k < downsampling_steps.
        downsizing_blocks, same_size_blocks = [], []
        for block_idx in range(n_res_blocks):
            downsample = block_idx < downsampling_steps
            block = BottomUpDeterministicResBlock(
                c_in=n_filters,
                c_out=n_filters,
                nonlin=nonlin,
                downsample=downsample,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
                res_block_kernel=res_block_kernel,
                skip_padding=res_block_skip_padding,
                gated=gated,
            )
            target = downsizing_blocks if downsample else same_size_blocks
            target.append(block)

        self.net_downsized = nn.Sequential(*downsizing_blocks)
        self.net = nn.Sequential(*same_size_blocks)

        # LC modules only exist when multiscale is enabled.
        self.lowres_net = None
        self.lowres_merge = None
        if self.enable_multiscale:
            self._init_multiscale(
                n_filters=n_filters,
                nonlin=nonlin,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
            )

    def _init_multiscale(
        self,
        nonlin: Callable = None,
        n_filters: int = None,
        batchnorm: bool = None,
        dropout: float = None,
        res_block_type: str = None,
    ) -> None:
        """
        Build the modules that merge the low-res lateral input into the primary flow.

        Sets `lowres_net`, which encodes the low-res input. It is `self.net`
        itself, or a deep copy of it when `lowres_separate_branch` is True.
        Also sets `lowres_merge`, a `MergeLowRes` layer with `merge_type="residual"`.
        Per the original note, it presumably concatenates on dim=1 and applies a
        1x1 convolution and a gated residual block; verify against `MergeLowRes`.

        Parameters
        ----------
        nonlin: Callable, optional
            The non-linearity class used in the merge layer. Default is `None`.
        n_filters: int
            Number of channels of the tensors being merged.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `None`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of the residual block.
            Check the `ResidualBlock` docstring for more information.
            Default is `None`.
        """
        self.lowres_net = deepcopy(self.net) if self.lowres_separate_branch else self.net

        self.lowres_merge = MergeLowRes(
            channels=n_filters,
            merge_type="residual",
            nonlin=nonlin,
            batchnorm=batchnorm,
            dropout=dropout,
            res_block_type=res_block_type,
            multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
            multiscale_lowres_size_factor=self.multiscale_lowres_size_factor,
        )

    def forward(
        self, x: torch.Tensor, lowres_x: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        x: torch.Tensor
            The input of the `BottomUpLayer`, i.e., the input image or the output of the
            previous layer.
        lowres_x: torch.Tensor, optional
            The low-res input used for Lateral Contextualization (LC). Default is `None`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The (merged) primary-flow output and the value to use in the top-down
            pass. They are the same tensor unless `multiscale_retain_spatial_dims`
            is True and `decoder_retain_spatial_dims` is False. In that case the
            second one is cropped with `crop_img_tensor`.
        """
        # Downsampling block(s) first, then the same-size block(s).
        primary_flow = self.net(self.net_downsized(x))

        # Without LC the primary-flow output is used for both outputs.
        if self.enable_multiscale is False:
            assert lowres_x is None
            return primary_flow, primary_flow

        if lowres_x is None:
            merged = primary_flow
        else:
            # Encode the low-res lateral input, then merge it into the primary flow.
            encoded_lowres = self.lowres_net(lowres_x)
            merged = self.lowres_merge(primary_flow, encoded_lowres)

        if (
            self.multiscale_retain_spatial_dims is False
            or self.decoder_retain_spatial_dims is True
        ):
            return merged, merged

        if self.output_expected_shape is None:
            fac = self.multiscale_lowres_size_factor
            expected_shape = (merged.shape[-2] // fac, merged.shape[-1] // fac)
            assert merged.shape[-2:] != expected_shape
        else:
            expected_shape = self.output_expected_shape

        # Crop the resulting tensor so that it matches with the Decoder.
        return merged, crop_img_tensor(merged, expected_shape)
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
class MergeLayer(nn.Module):
    """
    This layer merges two or more 4D input tensors by concatenating along dim=1 and passes the result through:
    a) a convolutional 1x1 layer (`merge_type == "linear"`), or
    b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
    c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
    """

    def __init__(
        self,
        merge_type: Literal["linear", "residual", "residual_ungated"],
        channels: Union[int, Iterable[int]],
        nonlin: Callable = nn.LeakyReLU,
        batchnorm: bool = True,
        dropout: float = None,
        res_block_type: str = None,
        res_block_kernel: int = None,
        res_block_skip_padding: bool = False,
        conv2d_bias: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        merge_type: Literal["linear", "residual", "residual_ungated"]
            The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
            Check the class docstring for more information about the behaviour of different merge modalities.
            NOTE(review): any other value leaves `self.layer` undefined, so `forward` would fail later.
        channels: Union[int, Iterable[int]]
            The number of channels used in the convolutional blocks of this layer.
            If it is an `int` (or an iterable with a single entry):
            - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
            - (Optional) ResBlock: in_channels=channels, out_channels=channels
            Otherwise (typically `len(channels)==3`); any iterable, including generators, is accepted:
            - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
            - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is `nn.LeakyReLU`.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            It can be either a single integer or a pair of integers defining the squared kernel.
            Default is `None`.
        res_block_skip_padding: bool, optional
            Whether to skip padding in convolutions in the Residual block. Default is `False`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.
        """
        super().__init__()
        try:
            iter(channels)
        except TypeError:  # a single int: same width for every input and the output
            channels = [channels] * 3
        else:
            # Materialize the iterable so that `len()` and slicing below also
            # work for non-sequence iterables (e.g. generators).
            channels = list(channels)
            if len(channels) == 1:
                channels = [channels[0]] * 3

        # The 1x1 conv maps the concatenated inputs (all entries but the last)
        # to the output width (last entry).
        in_channels = sum(channels[:-1])
        out_channels = channels[-1]

        if merge_type == "linear":
            self.layer = nn.Conv2d(in_channels, out_channels, 1, bias=conv2d_bias)
        elif merge_type == "residual":
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=conv2d_bias),
                ResidualGatedBlock(
                    out_channels,
                    nonlin,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    block_type=res_block_type,
                    kernel=res_block_kernel,
                    conv2d_bias=conv2d_bias,
                    skip_padding=res_block_skip_padding,
                ),
            )
        elif merge_type == "residual_ungated":
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=conv2d_bias),
                ResidualBlock(
                    out_channels,
                    nonlin,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    block_type=res_block_type,
                    kernel=res_block_kernel,
                    conv2d_bias=conv2d_bias,
                    skip_padding=res_block_skip_padding,
                ),
            )

    def forward(self, *args) -> torch.Tensor:
        """
        Concatenate the input tensors along dim=1 and apply the merge layer.

        Parameters
        ----------
        *args: torch.Tensor
            The tensors to merge; they must agree on every dimension except dim=1.

        Returns
        -------
        torch.Tensor
            The output of `self.layer` applied to the concatenated tensor.
        """
        # Concatenate the input tensors along dim=1
        x = torch.cat(args, dim=1)

        # Pass the concatenated tensor through the conv layer
        return self.layer(x)
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
class MergeLowRes(MergeLayer):
    """
    Child class of `MergeLayer`, specifically designed to merge the low-resolution patches
    that are used in Lateral Contextualization approach.
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor.

        All positional and keyword arguments are forwarded to `MergeLayer`, except the two
        keyword arguments below, which are removed from `kwargs` before forwarding
        (a `KeyError` is raised if either is missing).

        Parameters
        ----------
        multiscale_retain_spatial_dims: bool
            If truthy, `forward` pads the latent tensor up to the low-res tensor's spatial
            shape; otherwise it center-crops the low-res tensor.
        multiscale_lowres_size_factor: int
            Divisor applied to the low-res tensor's spatial size to obtain the crop size
            (presumably a positive int - TODO confirm against callers).
        """
        self.retain_spatial_dims = kwargs.pop("multiscale_retain_spatial_dims")
        self.multiscale_lowres_size_factor = kwargs.pop("multiscale_lowres_size_factor")
        super().__init__(*args, **kwargs)

    def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
        """
        Align `latent` and `lowres` spatially and merge them with `MergeLayer.forward`.

        Parameters
        ----------
        latent: torch.Tensor
            The output latent tensor from previous layer in the LVAE hierarchy.
        lowres: torch.Tensor
            The low-res patch image to be merged to increase the context.

        Returns
        -------
        torch.Tensor
            The merged tensor.
        """
        if self.retain_spatial_dims:
            # Pad latent tensor to match lowres tensor's shape
            latent = pad_img_tensor(latent, lowres.shape[2:])
        else:
            # Crop lowres tensor to match latent tensor's shape
            lh, lw = lowres.shape[-2:]
            h = lh // self.multiscale_lowres_size_factor
            w = lw // self.multiscale_lowres_size_factor
            h_pad = (lh - h) // 2
            w_pad = (lw - w) // 2
            # Explicit end indices: a `-h_pad` stop becomes `-0 == 0` (an empty
            # slice) when no cropping is needed, e.g. with a size factor of 1.
            lowres = lowres[:, :, h_pad : lh - h_pad, w_pad : lw - w_pad]

        return super().forward(latent, lowres)
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
class SkipConnectionMerger(MergeLayer):
    """
    `MergeLayer` specialization used to fuse a skip-connection tensor into the
    current top-down activation of the model.
    """

    def __init__(
        self,
        nonlin: Callable,
        channels: Union[int, Iterable[int]],
        batchnorm: bool,
        dropout: float,
        res_block_type: str,
        merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
        conv2d_bias: bool = True,
        res_block_kernel: int = None,
        res_block_skip_padding: bool = False,
    ):
        """
        Constructor. Every argument is forwarded unchanged to `MergeLayer`.

        Parameters
        ----------
        nonlin: Callable
            The non-linearity function used in the block (required).
        channels: Union[int, Iterable[int]]
            The number of channels used in the convolutional blocks of this layer.
            If it is an `int`:
            - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
            - (Optional) ResBlock: in_channels=channels, out_channels=channels
            If it is an Iterable (typically `len(channels)==3`):
            - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
            - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
        batchnorm: bool
            Whether to use batchnorm layers (required).
        dropout: float
            The dropout probability in dropout layers; `None` disables dropout (required).
        res_block_type: str
            A string specifying the structure of the residual block (required).
            Check the `ResidualBlock` docstring for more information.
        merge_type: Literal["linear", "residual", "residual_ungated"], optional
            The type of merge done in the layer. Check the `MergeLayer` docstring for the
            behaviour of each modality. Default is "residual".
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            Either a single integer or a pair of integers. Default is `None`.
        res_block_skip_padding: bool, optional
            Whether to skip padding in convolutions in the residual block. Default is `False`.
        """
        forwarded = dict(
            merge_type=merge_type,
            channels=channels,
            nonlin=nonlin,
            batchnorm=batchnorm,
            dropout=dropout,
            res_block_type=res_block_type,
            res_block_kernel=res_block_kernel,
            res_block_skip_padding=res_block_skip_padding,
            conv2d_bias=conv2d_bias,
        )
        super().__init__(**forwarded)
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
class TopDownLayer(nn.Module):
    """
    Top-down inference layer.
    It includes:
    - Stochastic sampling,
    - Computation of KL divergence,
    - A small deterministic ResNet that performs upsampling.

    NOTE 1:
    The algorithm for generative inference approximately works as follows:
    - p_params = output of top-down layer above
    - bu = inferred bottom-up value at this layer
    - q_params = merge(bu, p_params)
    - z = stochastic_layer(q_params)
    - (optional) get and merge skip connection from prev top-down layer
    - top-down deterministic ResNet

    NOTE 2:
    The Top-Down layer can work in two modes: inference and prediction/generative.
    Depending on the particular mode, it follows distinct behaviours:
    - In inference mode, parameters of q(z_i|z_i+1) are obtained from the inference path,
    by merging outcomes of bottom-up and top-down passes. The exception is the top layer,
    in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer.
    - On the contrary in prediction/generative mode, parameters of q(z_i|z_i+1) can be obtained
    once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is
    possible to directly sample from the prior p(z_i|z_i+1) (UNCONDITIONAL GENERATION).

    NOTE 3:
    When doing unconditional generation, bu_value is not available. Hence the
    merge layer is not used, and z is sampled directly from p_params.

    NOTE 4:
    If this is the top layer, at inference time, the uppermost bottom-up value
    is used directly as q_params, and p_params are defined in this layer
    (while they are usually taken from the previous layer), and can be learned.
    """

    def __init__(
        self,
        z_dim: int,
        n_res_blocks: int,
        n_filters: int,
        is_top_layer: bool = False,
        downsampling_steps: int = None,
        nonlin: Callable = None,
        merge_type: Literal["linear", "residual", "residual_ungated"] = None,
        batchnorm: bool = True,
        dropout: float = None,
        stochastic_skip: bool = False,
        res_block_type: str = None,
        res_block_kernel: int = None,
        res_block_skip_padding: bool = None,
        groups: int = 1,
        gated: bool = None,
        learn_top_prior: bool = False,
        top_prior_param_shape: Iterable[int] = None,
        analytical_kl: bool = False,
        bottomup_no_padding_mode: bool = False,
        topdown_no_padding_mode: bool = False,
        retain_spatial_dims: bool = False,
        restricted_kl: bool = False,
        vanilla_latent_hw: Iterable[int] = None,
        non_stochastic_version: bool = False,
        input_image_shape: Union[None, Tuple[int, int]] = None,
        normalize_latent_factor: float = 1.0,
        conv2d_bias: bool = True,
        stochastic_use_naive_exponential: bool = False,
    ):
        """
        Constructor.

        Parameters
        ----------
        z_dim: int
            The size of the latent space.
        n_res_blocks: int
            The number of TopDownDeterministicResBlock blocks
        n_filters: int
            The number of channels present through out the layers of this block.
        is_top_layer: bool, optional
            Whether the current layer is at the top of the Decoder hierarchy. Default is `False`.
        downsampling_steps: int, optional
            The number of downsampling steps that has to be done in this layer (typically 1).
            `None` is treated as 0 (no resampling). Default is `None`.
        nonlin: Callable, optional
            The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
        merge_type: Literal["linear", "residual", "residual_ungated"], optional
            The type of merge done in the layer. It can be chosen between "linear", "residual",
            and "residual_ungated". Check the `MergeLayer` class docstring for more information
            about the behaviour of different merging modalities. Default is `None`.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        stochastic_skip: bool, optional
            Whether to use skip connections between previous top-down layer's output and this layer's stochastic output.
            Stochastic skip connection allows the previous layer's output has a way to directly reach this hierarchical
            level, hence facilitating the gradient flow during backpropagation. Default is `False`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` documentation for more information.
            Default is `None`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            It can be either a single integer or a pair of integers defining the squared kernel.
            Default is `None`.
        res_block_skip_padding: bool, optional
            Whether to skip padding in convolutions in the Residual block. Default is `None`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        gated: bool, optional
            Whether to use gated layer in `ResidualBlock`. Default is `None`.
        learn_top_prior:
            Whether to set the top prior as learnable.
            If this is set to `False`, in the top-most layer the prior will be N(0,1).
            Otherwise, we will still have a normal distribution whose parameters will be learnt.
            Default is `False`.
        top_prior_param_shape: Iterable[int], optional
            The size of the tensor which expresses the mean and the variance
            of the prior for the top most layer. Default is `None`.
        analytical_kl: bool, optional
            If True, KL divergence is calculated according to the analytical formula.
            Otherwise, an MC approximation using sampled latents is calculated.
            Default is `False`.
        bottomup_no_padding_mode: bool, optional
            Whether padding is used in the different layers of the bottom-up pass.
            It is meaningful to know this in advance in order to assess whether before
            merging `bu_values` and `p_params` tensors any alignment is needed.
            Default is `False`.
        topdown_no_padding_mode: bool, optional
            Whether padding is used in the different layers of the top-down pass.
            It is meaningful to know this in advance in order to assess whether before
            merging `bu_values` and `p_params` tensors any alignment is needed.
            The same information is also needed in handling the skip connections between
            top-down layers. Default is `False`.
        retain_spatial_dims: bool, optional
            If `True`, the size of Encoder's latent space is kept to `input_image_shape` within the topdown layer.
            This implies that the output spatial size equals the input spatial size.
            To achieve this, we centercrop the intermediate representation.
            Default is `False`.
        restricted_kl: bool, optional
            Whether to compute the restricted version of KL Divergence.
            See `NormalStochasticBlock2d` module for more information about its computation.
            Default is `False`.
        vanilla_latent_hw: Iterable[int], optional
            The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
            Default is `None`.
        non_stochastic_version: bool, optional
            Whether to replace the stochastic layer that samples a latent variable from the latent distribution with
            a non-stochastic layer that simply draws a sample as the mode of the latent distribution.
            Default is `False`.
        input_image_shape: Tuple[int, int], optional
            The shape of the input image tensor.
            When `retain_spatial_dims` is set to `True`, this is used to ensure that the shape of this layer
            output has the same shape as the input. Default is `None`.
        normalize_latent_factor: float, optional
            A factor used to normalize the latent tensors `q_params`.
            Specifically, normalization is done by dividing the latent tensor by this factor.
            Default is 1.0.
        conv2d_bias: bool, optional
            Whether to use bias term is the convolutional blocks of this layer.
            Default is `True`.
        stochastic_use_naive_exponential: bool, optional
            If `False`, in the NormalStochasticBlock2d exponentials are computed according
            to the alternative definition provided by `StableExponential` class.
            This should improve numerical stability in the training process.
            Default is `False`.
        """
        super().__init__()

        self.is_top_layer = is_top_layer
        self.z_dim = z_dim
        self.stochastic_skip = stochastic_skip
        self.learn_top_prior = learn_top_prior
        self.analytical_kl = analytical_kl
        self.bottomup_no_padding_mode = bottomup_no_padding_mode
        self.topdown_no_padding_mode = topdown_no_padding_mode
        self.retain_spatial_dims = retain_spatial_dims
        self.latent_shape = input_image_shape if self.retain_spatial_dims else None
        self.non_stochastic_version = non_stochastic_version
        self.normalize_latent_factor = normalize_latent_factor
        self._vanilla_latent_hw = vanilla_latent_hw

        # Define top layer prior parameters, possibly learnable
        if is_top_layer:
            self.top_prior_params = nn.Parameter(
                torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
            )

        # Downsampling steps left to do in this layer.
        # `None` (the default) means no resampling; comparing `None > 0` would
        # otherwise raise a TypeError.
        dws_left = downsampling_steps or 0

        # Define deterministic top-down block, which is a sequence of deterministic
        # residual blocks with (optional) downsampling.
        block_list = []
        for _ in range(n_res_blocks):
            do_resample = False
            if dws_left > 0:
                do_resample = True
                dws_left -= 1
            block_list.append(
                TopDownDeterministicResBlock(
                    c_in=n_filters,
                    c_out=n_filters,
                    nonlin=nonlin,
                    upsample=do_resample,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    res_block_type=res_block_type,
                    res_block_kernel=res_block_kernel,
                    skip_padding=res_block_skip_padding,
                    gated=gated,
                    conv2d_bias=conv2d_bias,
                    groups=groups,
                )
            )
        self.deterministic_block = nn.Sequential(*block_list)

        # Define stochastic block with 2D convolutions
        if self.non_stochastic_version:
            self.stochastic = NonStochasticBlock2d(
                c_in=n_filters,
                c_vars=z_dim,
                c_out=n_filters,
                transform_p_params=(not is_top_layer),
                groups=groups,
                conv2d_bias=conv2d_bias,
            )
        else:
            self.stochastic = NormalStochasticBlock2d(
                c_in=n_filters,
                c_vars=z_dim,
                c_out=n_filters,
                transform_p_params=(not is_top_layer),
                vanilla_latent_hw=vanilla_latent_hw,
                restricted_kl=restricted_kl,
                use_naive_exponential=stochastic_use_naive_exponential,
            )

        if not is_top_layer:
            # Merge layer: it combines bottom-up inference and top-down
            # generative outcomes to give posterior parameters
            self.merge = MergeLayer(
                channels=n_filters,
                merge_type=merge_type,
                nonlin=nonlin,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
                res_block_kernel=res_block_kernel,
                conv2d_bias=conv2d_bias,
            )

            # Skip connection that goes around the stochastic top-down layer
            if stochastic_skip:
                self.skip_connection_merger = SkipConnectionMerger(
                    channels=n_filters,
                    nonlin=nonlin,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    res_block_type=res_block_type,
                    merge_type=merge_type,
                    conv2d_bias=conv2d_bias,
                    res_block_kernel=res_block_kernel,
                    res_block_skip_padding=res_block_skip_padding,
                )

    def sample_from_q(
        self,
        input_: torch.Tensor,
        bu_value: torch.Tensor,
        var_clip_max: float = None,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        This method computes the latent inference distribution q(z_i|z_{i+1}) and samples a latent tensor from it.

        Parameters
        ----------
        input_: torch.Tensor
            The input tensor to the layer, which is the output of the top-down layer above.
        bu_value: torch.Tensor
            The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
            at the correspondent hierarchical layer.
        var_clip_max: float, optional
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped. Default is `None`.
        mask: Union[None, torch.Tensor], optional
            A tensor used to index (mask) the sampled latent tensor. Default is `None`.
        """
        if self.is_top_layer:  # In top layer, we don't merge bu_value with p_params
            q_params = bu_value
        else:
            # NOTE: Here the assumption is that the vampprior is only applied on the top layer.
            n_img_prior = None
            p_params = self.get_p_params(input_, n_img_prior)
            q_params = self.merge(bu_value, p_params)

        sample = self.stochastic.sample_from_q(q_params, var_clip_max)

        # Compare against None explicitly: the truth value of a multi-element
        # tensor is ambiguous and raises a RuntimeError.
        if mask is not None:
            return sample[mask]

        return sample

    def get_p_params(
        self,
        input_: torch.Tensor,
        n_img_prior: int,
    ) -> torch.Tensor:
        """
        This method returns the parameters of the prior distribution p(z_i|z_{i+1}) for the latent tensor
        depending on the hierarchical level of the layer and other specific conditions.

        Parameters
        ----------
        input_: torch.Tensor
            The input tensor to the layer, which is the output of the top-down layer above.
        n_img_prior: int
            The number of images to be generated from the unconditional prior distribution p(z_L).
        """
        p_params = None

        # If top layer, define p_params as the ones of the prior p(z_L)
        if self.is_top_layer:
            p_params = self.top_prior_params

            # Sample specific number of images by expanding the prior
            if n_img_prior is not None:
                p_params = p_params.expand(n_img_prior, -1, -1, -1)

        # Else the input from the layer above is p_params itself
        else:
            p_params = input_

        return p_params

    def align_pparams_buvalue(
        self, p_params: torch.Tensor, bu_value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        In case the padding is not used either (or both) in encoder and decoder, we could have a shape mismatch
        in the spatial dimensions (usually, dim=2 & dim=3).
        This method performs a centercrop to ensure that both remain aligned.

        Parameters
        ----------
        p_params: torch.Tensor
            The tensor defining the parameters /mu_p and /sigma_p for the latent distribution p(z_i|z_{i+1}).
        bu_value: torch.Tensor
            The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
            at the correspondent hierarchical layer.
        """
        if bu_value.shape[-2:] != p_params.shape[-2:]:
            assert self.bottomup_no_padding_mode is True
            if self.topdown_no_padding_mode is False:
                assert bu_value.shape[-1] > p_params.shape[-1]
                bu_value = F.center_crop(bu_value, p_params.shape[-2:])
            else:
                if bu_value.shape[-1] > p_params.shape[-1]:
                    bu_value = F.center_crop(bu_value, p_params.shape[-2:])
                else:
                    p_params = F.center_crop(p_params, bu_value.shape[-2:])
        return p_params, bu_value

    def forward(
        self,
        input_: torch.Tensor = None,
        skip_connection_input: torch.Tensor = None,
        inference_mode: bool = False,
        bu_value: torch.Tensor = None,
        n_img_prior: int = None,
        forced_latent: torch.Tensor = None,
        use_mode: bool = False,
        force_constant_output: bool = False,
        mode_pred: bool = False,
        use_uncond_mode: bool = False,
        var_clip_max: float = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Parameters
        ----------
        input_: torch.Tensor, optional
            The input tensor to the layer, which is the output of the top-down layer above.
            Default is `None`.
        skip_connection_input: torch.Tensor, optional
            The tensor brought by the skip connection between the current and the previous top-down layer.
            Default is `None`.
        inference_mode: bool, optional
            Whether the layer is in inference mode. See NOTE 2 in class description for more info.
            Default is `False`.
        bu_value: torch.Tensor, optional
            The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
            at the correspondent hierarchical layer. Default is `None`.
        n_img_prior: int, optional
            The number of images to be generated from the unconditional prior distribution p(z_L).
            Default is `None`.
        forced_latent: torch.Tensor, optional
            A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and,
            hence, sampling does not happen. Default is `None`.
        use_mode: bool, optional
            Whether the latent tensor should be set as the latent distribution mode.
            In the case of Gaussian, the mode coincides with the mean of the distribution.
            Default is `False`.
        force_constant_output: bool, optional
            Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
            This is used when doing experiment from the prior - q is not used.
            Default is `False`.
        mode_pred: bool, optional
            Whether the model is in prediction mode. Default is `False`.
        use_uncond_mode: bool, optional
            Whether to use the unconditional distribution p(z) to sample latents in prediction mode.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]
            The layer output, the activation before the deterministic residual block
            (usable as skip connection input by the next layer), and a dictionary of
            quantities used in the loss computation.
        """
        # Check consistency of arguments
        inputs_none = input_ is None and skip_connection_input is None
        if self.is_top_layer and not inputs_none:
            raise ValueError("In top layer, inputs should be None")

        p_params = self.get_p_params(input_, n_img_prior)

        # Get the parameters for the latent distribution to sample from
        if inference_mode:
            if self.is_top_layer:
                q_params = bu_value
                if mode_pred is False:
                    p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
            else:
                if use_uncond_mode:
                    q_params = p_params
                else:
                    p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
                    q_params = self.merge(bu_value, p_params)
        # In generative mode, q is not used
        else:
            q_params = None

        # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
        # depending on the mode (hence, in practice, by checking whether q_params is None).

        # Normalization of latent space parameters:
        # it is done, purely for stability. See Very deep VAEs generalize autoregressive models.
        # In generative mode q_params is None, so there is nothing to normalize.
        if self.normalize_latent_factor and q_params is not None:
            q_params = q_params / self.normalize_latent_factor

        # Sample (and process) a latent tensor in the stochastic layer
        x, data_stoch = self.stochastic(
            p_params=p_params,
            q_params=q_params,
            forced_latent=forced_latent,
            use_mode=use_mode,
            force_constant_output=force_constant_output,
            analytical_kl=self.analytical_kl,
            mode_pred=mode_pred,
            use_uncond_mode=use_uncond_mode,
            var_clip_max=var_clip_max,
        )

        # Merge skip connection from previous layer
        if self.stochastic_skip and not self.is_top_layer:
            if self.topdown_no_padding_mode is True:
                # If no padding is done in the current top-down pass, there may be a shape mismatch between
                # current tensor and skip connection input. As an example, if the output of last TopDownLayer
                # was of size 64*64, due to lack of padding in the current layer, the current tensor might
                # become different in shape, say 60*60.
                # In order to avoid shape mismatch, we do central crop of the skip connection input.
                skip_connection_input = F.center_crop(
                    skip_connection_input, x.shape[-2:]
                )

            x = self.skip_connection_merger(x, skip_connection_input)

        # Save activation before residual block as it can be the skip connection input in the next layer
        x_pre_residual = x

        if self.retain_spatial_dims:
            # when we don't want to do padding in topdown as well, we need to spare some boundary pixels
            # which would be used up.
            extra_len = (self.topdown_no_padding_mode is True) * 3

            # this means that x should be of the same size as config.data.image_size.
            # So, we have to centercrop by a factor of 2 at this point.
            # we assume that one topdown layer will have exactly one upscaling layer.
            new_latent_shape = (
                self.latent_shape[0] // 2 + extra_len,
                self.latent_shape[1] // 2 + extra_len,
            )

            # If the LC is not applied on all layers, then this can happen.
            if x.shape[-1] > new_latent_shape[-1]:
                x = F.center_crop(x, new_latent_shape)

        # Last top-down block (sequence of residual blocks)
        x = self.deterministic_block(x)

        # NOTE(review): self.latent_shape is None unless retain_spatial_dims is set;
        # confirm topdown_no_padding_mode is only used together with it.
        if self.topdown_no_padding_mode:
            x = F.center_crop(x, self.latent_shape)

        # Save some metrics that will be used in the loss computation
        keys = [
            "z",
            "kl_samplewise",
            "kl_samplewise_restricted",
            "kl_spatial",
            "kl_channelwise",
            # 'logprob_p',
            "logprob_q",
            "qvar_max",
        ]
        data = {k: data_stoch.get(k, None) for k in keys}
        data["q_mu"] = None
        data["q_lv"] = None
        if data_stoch["q_params"] is not None:
            q_mu, q_lv = data_stoch["q_params"]
            data["q_mu"] = q_mu
            data["q_lv"] = q_lv

        return x, x_pre_residual, data
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
class NormalStochasticBlock2d(nn.Module):
    """
    Stochastic block used in the Top-Down inference pass.

    Algorithm:
        - map input parameters to q(z) and (optionally) p(z) via convolution
        - sample a latent tensor z ~ q(z)
        - feed z to convolution and return.

    NOTE 1:
        If parameters for q are not given, sampling is done from p(z).

    NOTE 2:
        The restricted KL divergence is obtained by first computing the element-wise KL divergence
        (i.e., the KL computed for each element of the latent tensors). Then, the restricted version
        is computed by summing over the channels and the spatial dimensions associated only to the
        portion of the latent tensor that is used for prediction.
    """

    def __init__(
        self,
        c_in: int,
        c_vars: int,
        c_out: int,
        kernel: int = 3,
        transform_p_params: bool = True,
        vanilla_latent_hw: Union[int, None] = None,
        restricted_kl: bool = False,
        use_naive_exponential: bool = False,
    ):
        """
        Parameters
        ----------
        c_in: int
            The number of channels of the input tensor.
        c_vars: int
            The number of channels of the latent space tensor.
        c_out: int
            The number of channels of the output of the stochastic layer.
            Note that this is different from the sampled latent z.
        kernel: int, optional
            The size of the kernel used in convolutional layers. Must be odd.
            Default is 3.
        transform_p_params: bool, optional
            Whether a transformation should be applied to the `p_params` tensor.
            The transformation consists in a 2D convolution (`conv_in_p`) that
            maps the input to `2 * c_vars` channels. If `False`, `p_params` must
            already have `2 * c_vars` channels.
            Default is `True`.
        vanilla_latent_hw: int, optional
            The spatial size of the latent tensor used for prediction (it determines
            the region used by the restricted KL). Required when `restricted_kl` is `True`.
            Default is `None`.
        restricted_kl: bool, optional
            Whether to compute the restricted version of KL Divergence.
            See NOTE 2 for more information about its computation.
            Default is `False`.
        use_naive_exponential: bool, optional
            If `False`, exponentials are computed according to the alternative definition
            provided by `StableExponential` class. This should improve numerical stability
            in the training process. Default is `False`.
        """
        super().__init__()
        # An odd kernel with padding = kernel // 2 keeps the spatial size unchanged.
        assert kernel % 2 == 1
        pad = kernel // 2
        self.transform_p_params = transform_p_params
        self.c_in = c_in
        self.c_out = c_out
        self.c_vars = c_vars
        self._use_naive_exponential = use_naive_exponential
        self._vanilla_latent_hw = vanilla_latent_hw
        self._restricted_kl = restricted_kl

        if transform_p_params:
            self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
        self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
        self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)

    def get_z(
        self,
        sampling_distrib: torch.distributions.normal.Normal,
        forced_latent: torch.Tensor,
        use_mode: bool,
        mode_pred: bool,
        use_uncond_mode: bool,
    ) -> torch.Tensor:
        """
        Obtain a latent tensor from the given distribution.

        The latent tensor is, in order of precedence:
            - `forced_latent`, if it is not `None` (no sampling happens);
            - the mode (mean) of the distribution, if `use_mode` is `True`, or if
              in prediction mode (`mode_pred`) with `use_uncond_mode` set;
            - otherwise, a reparameterized sample (`rsample`) from the distribution.

        Parameters
        ----------
        sampling_distrib: torch.distributions.normal.Normal
            The Gaussian distribution from which the latent tensor is sampled.
        forced_latent: torch.Tensor
            A pre-defined latent tensor. If it is not `None`, then it is used as the actual
            latent tensor and, hence, sampling does not happen.
        use_mode: bool
            Whether the latent tensor should be set as the latent distribution mode.
            In the case of Gaussian, the mode coincides with the mean of the distribution.
        mode_pred: bool
            Whether the model is in prediction mode.
        use_uncond_mode: bool
            In prediction mode, whether to take the mode of the distribution instead of sampling.

        Returns
        -------
        torch.Tensor
            The latent tensor.
        """
        if forced_latent is not None:
            return forced_latent
        if use_mode or (mode_pred and use_uncond_mode):
            return sampling_distrib.mean
        return sampling_distrib.rsample()

    def sample_from_q(
        self, q_params: torch.Tensor, var_clip_max: float
    ) -> torch.Tensor:
        """
        Build q(z) from its input parameters and draw a reparameterized sample from it.

        The parameters are processed by `process_q_params()`.

        Parameters
        ----------
        q_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped.

        Returns
        -------
        torch.Tensor
            A sample from q(z).
        """
        _, _, q = self.process_q_params(q_params, var_clip_max)
        return q.rsample()

    def compute_kl_metrics(
        self,
        p: torch.distributions.normal.Normal,
        p_params: torch.Tensor,
        q: torch.distributions.normal.Normal,
        q_params: torch.Tensor,
        mode_pred: bool,
        analytical_kl: bool,
        z: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the KL (analytical or MC estimate) and its aggregated versions.

        The returned dictionary contains:
            - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
            - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
            - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
              used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
              `None` unless the block was built with `restricted_kl=True`.
            - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
            - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)].

        In prediction mode (`mode_pred=True`) no KL is computed and all entries are `None`.

        Parameters
        ----------
        p: torch.distributions.normal.Normal
            The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
        p_params: torch.Tensor
            The parameters of the prior generative distribution (only used by the MC estimate).
        q: torch.distributions.normal.Normal
            The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
        q_params: torch.Tensor
            The parameters of the inference distribution (only used by the MC estimate).
        mode_pred: bool
            Whether the model is in prediction mode.
        analytical_kl: bool
            Whether to compute the KL divergence analytically or using Monte Carlo estimation.
        z: torch.Tensor
            The sampled latent tensor (only used by the MC estimate).

        Returns
        -------
        Dict[str, torch.Tensor]
            The KL metrics described above.
        """
        kl_samplewise_restricted = None

        if mode_pred is False:  # if not in prediction mode
            # KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]
            if analytical_kl:
                kl_elementwise = kl_divergence(q, p)
            else:
                kl_elementwise = kl_normal_mc(z, p_params, q_params)

            # KL term only associated to the central region of size vanilla_latent_hw
            # (the portion used for prediction), summed over channel and spatial dims.
            # NOTE: the margin is computed from the last spatial dimension and applied
            # to both spatial dimensions, i.e. square latents are assumed.
            if self._restricted_kl:
                pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2
                assert pad > 0, "Disable restricted kl since there is no restriction."
                tmp = kl_elementwise[..., pad:-pad, pad:-pad]
                kl_samplewise_restricted = tmp.sum((1, 2, 3))

            # KL term associated to each sample in the batch [Shape: (batch, )]
            kl_samplewise = kl_elementwise.sum((1, 2, 3))

            # KL term associated to each sample and each channel [Shape: (batch, ch, )]
            kl_channelwise = kl_elementwise.sum((2, 3))

            # KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
            kl_spatial = kl_elementwise.sum(1)
        else:  # if predicting, no need to compute KL
            kl_elementwise = kl_samplewise = kl_spatial = kl_channelwise = None

        kl_dict = {
            "kl_elementwise": kl_elementwise,  # (batch, ch, h, w)
            "kl_samplewise": kl_samplewise,  # (batch, )
            "kl_samplewise_restricted": kl_samplewise_restricted,  # (batch, )
            "kl_channelwise": kl_channelwise,  # (batch, ch)
            "kl_spatial": kl_spatial,  # (batch, h, w)
        }
        return kl_dict

    def process_p_params(
        self, p_params: torch.Tensor, var_clip_max: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
        """
        Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)).

        Processing consists in:
            - (optionally) 2D convolution on the input tensor to map it to `2 * c_vars` channels.
            - split the resulting tensor into two chunks, the mean and the log-variance.
            - (optionally) clip the log-variance to an upper threshold.
            - define the normal distribution p(z) given the parameter tensors above.

        Parameters
        ----------
        p_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped. `None` disables clipping.

        Returns
        -------
        tuple
            The mean (wrapped in `StableMean`), the log-variance (wrapped in
            `StableLogVar`) and the resulting `Normal` distribution.
        """
        if self.transform_p_params:
            p_params = self.conv_in_p(p_params)
        else:
            assert p_params.size(1) == 2 * self.c_vars

        # Define p(z)
        p_mu, p_lv = p_params.chunk(2, dim=1)
        if var_clip_max is not None:
            p_lv = torch.clip(p_lv, max=var_clip_max)

        p_mu = StableMean(p_mu)
        p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential)
        p = Normal(p_mu.get(), p_lv.get_std())
        return p_mu, p_lv, p

    def process_q_params(
        self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
        """
        Process the input parameters to get the inference distribution q(z_i|z_{i+1}) (or q(z|x)).

        Processing consists in:
            - 2D convolution on the input tensor to map it to `2 * c_vars` channels.
            - split the resulting tensor into two chunks, the mean and the log-variance.
            - (optionally) clip the log-variance to an upper threshold.
            - (optionally) crop the resulting tensors to ensure that the last spatial dimension is even.
            - define the normal distribution q(z) given the parameter tensors above.

        Parameters
        ----------
        q_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped. `None` disables clipping.
        allow_oddsizes: bool, optional
            If `False` and the last spatial dimension is odd, the mean and log-variance
            are center-cropped by one pixel. Default is `False`.

        Returns
        -------
        tuple
            The mean (wrapped in `StableMean`), the log-variance (wrapped in
            `StableLogVar`) and the resulting `Normal` distribution.
        """
        q_params = self.conv_in_q(q_params)

        q_mu, q_lv = q_params.chunk(2, dim=1)
        if var_clip_max is not None:
            q_lv = torch.clip(q_lv, max=var_clip_max)

        if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
            q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
            q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1)

        q_mu = StableMean(q_mu)
        q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential)
        q = Normal(q_mu.get(), q_lv.get_std())
        return q_mu, q_lv, q

    def forward(
        self,
        p_params: torch.Tensor,
        q_params: Union[torch.Tensor, None] = None,
        forced_latent: Union[torch.Tensor, None] = None,
        use_mode: bool = False,
        force_constant_output: bool = False,
        analytical_kl: bool = False,
        mode_pred: bool = False,
        use_uncond_mode: bool = False,
        var_clip_max: Union[float, None] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Sample a latent tensor and map it through the output convolution.

        Parameters
        ----------
        p_params: torch.Tensor
            The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
        q_params: torch.Tensor, optional
            The tensor resulting from merging the bu_value tensor at the same hierarchical level
            from the bottom-up pass and the `p_params` tensor. Default is `None`.
        forced_latent: torch.Tensor, optional
            A pre-defined latent tensor. If it is not `None`, then it is used as the actual latent
            tensor and, hence, sampling does not happen. Default is `None`.
        use_mode: bool, optional
            Whether the latent tensor should be set as the latent distribution mode.
            In the case of Gaussian, the mode coincides with the mean of the distribution.
            Default is `False`.
        force_constant_output: bool, optional
            Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
            This is used when doing experiment from the prior - q is not used.
            Default is `False`.
        analytical_kl: bool, optional
            Whether to compute the KL divergence analytically or using Monte Carlo estimation.
            Default is `False`.
        mode_pred: bool, optional
            Whether the model is in prediction mode. Default is `False`.
        use_uncond_mode: bool, optional
            In prediction mode, whether to take the mode of the distribution instead of sampling.
            Default is `False`.
        var_clip_max: float, optional
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped. Default is `None`.

        Returns
        -------
        tuple
            The output of `conv_out` applied to z, and a dictionary with the KL metrics
            (only when `q_params` is given), `z`, `p_params`, `q_params`, `logprob_q`
            and `qvar_max`.
        """
        debug_qvar_max = 0

        # Check sampling options consistency
        assert (forced_latent is None) or (not use_mode)

        # Get generative distribution p(z_i|z_{i+1})
        p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max)
        p_params = (p_mu, p_lv)

        if q_params is not None:
            # Get inference distribution q(z_i|z_{i+1})
            # NOTE: At inference time, don't centercrop the q_params even if they are odd in size.
            q_mu, q_lv, q = self.process_q_params(
                q_params, var_clip_max, allow_oddsizes=mode_pred is True
            )
            q_params = (q_mu, q_lv)
            sampling_distrib = q
            debug_qvar_max = torch.max(q_lv.get())

            # Centercrop p_params so that their size matches the one of q_params
            q_size = q_mu.get().shape[-1]
            if p_mu.get().shape[-1] != q_size and mode_pred is False:
                p_mu.centercrop_to_size(q_size)
                p_lv.centercrop_to_size(q_size)
                # `p` was built in `process_p_params` from the uncropped parameters;
                # rebuild it so that the analytical KL `kl_divergence(q, p)` compares
                # distributions of matching spatial size.
                p = Normal(p_mu.get(), p_lv.get_std())
        else:
            sampling_distrib = p

        # Sample latent variable
        z = self.get_z(
            sampling_distrib, forced_latent, use_mode, mode_pred, use_uncond_mode
        )

        # Copy one sample (and distrib parameters) over the whole batch.
        # This is used when doing experiment from the prior - q is not used.
        # NOTE(review): p_params holds StableMean/StableLogVar objects here; slicing
        # them assumes those wrappers support indexing and `expand_as` — TODO confirm.
        if force_constant_output:
            z = z[0:1].expand_as(z).clone()
            p_params = (
                p_params[0][0:1].expand_as(p_params[0]).clone(),
                p_params[1][0:1].expand_as(p_params[1]).clone(),
            )

        # Pass the sampled latent through the output convolutional layer of stochastic block
        out = self.conv_out(z)

        # NOTE: computation of log p(z) is disabled.
        if q_params is not None:
            # Compute log q(z)
            logprob_q = q.log_prob(z).sum((1, 2, 3))
            # Compute KL divergence metrics
            kl_dict = self.compute_kl_metrics(
                p, p_params, q, q_params, mode_pred, analytical_kl, z
            )
        else:
            kl_dict = {}
            logprob_q = None

        # Store meaningful quantities to use them in following layers
        data = kl_dict
        data["z"] = z  # sampled variable at this layer (batch, ch, h, w)
        data["p_params"] = p_params  # (b, ch, h, w) where b is 1 or batch size
        data["q_params"] = q_params  # (batch, ch, h, w)
        data["logprob_q"] = logprob_q  # (batch, )
        data["qvar_max"] = debug_qvar_max

        return out, data
|
|
1774
|
+
|
|
1775
|
+
|
|
1776
|
+
class NonStochasticBlock2d(nn.Module):
    """
    Non-stochastic version of the NormalStochasticBlock2d.

    The "latent" tensor is the mean half of the convolved q (or p) parameters: no
    sampling happens, the log-variance half is discarded and no KL is computed.
    """

    def __init__(
        self,
        c_vars: int,
        c_in: int,
        c_out: int,
        kernel: int = 3,
        groups: int = 1,
        conv2d_bias: bool = True,
        transform_p_params: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        c_vars: int
            The number of channels of the latent space tensor.
        c_in: int
            The number of channels of the input tensor.
        c_out: int
            The number of channels of the output of the stochastic layer.
            Note that this is different from the sampled latent z.
        kernel: int, optional
            The size of the kernel used in convolutional layers. Must be odd.
            Default is 3.
        groups: int, optional
            The number of groups to consider in the convolutions of this layer.
            Default is 1.
        conv2d_bias: bool, optional
            Whether to use bias term in the convolutional blocks of this layer.
            Default is `True`.
        transform_p_params: bool, optional
            Whether a transformation should be applied to the `p_params` tensor.
            The transformation consists in a 2D convolution (`conv_in_p`) that
            maps the input to `2 * c_vars` channels. If `False`, `p_params` must
            already have `2 * c_vars` channels.
            Default is `True`.
        """
        super().__init__()
        # An odd kernel with padding = kernel // 2 keeps the spatial size unchanged.
        assert kernel % 2 == 1
        pad = kernel // 2
        self.transform_p_params = transform_p_params
        self.c_in = c_in
        self.c_out = c_out
        self.c_vars = c_vars

        if transform_p_params:
            self.conv_in_p = nn.Conv2d(
                c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
            )
        self.conv_in_q = nn.Conv2d(
            c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
        )
        self.conv_out = nn.Conv2d(
            c_vars, c_out, kernel, padding=pad, bias=conv2d_bias, groups=groups
        )

    def compute_kl_metrics(
        self,
        p: torch.distributions.normal.Normal,
        p_params: torch.Tensor,
        q: torch.distributions.normal.Normal,
        q_params: torch.Tensor,
        mode_pred: bool,
        analytical_kl: bool,
        z: torch.Tensor,
    ) -> Dict[str, None]:
        """
        Return the KL metrics dictionary with every entry set to `None`.

        The keys mirror those of `NormalStochasticBlock2d.compute_kl_metrics`:
            - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
            - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
            - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
              used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
            - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
            - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]

        NOTE: in this class all the KL metrics are set to `None`; the arguments are
        accepted only for interface compatibility and are not used.

        Parameters
        ----------
        p: torch.distributions.normal.Normal
            The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
        p_params: torch.Tensor
            The parameters of the prior generative distribution.
        q: torch.distributions.normal.Normal
            The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
        q_params: torch.Tensor
            The parameters of the inference distribution.
        mode_pred: bool
            Whether the model is in prediction mode.
        analytical_kl: bool
            Whether to compute the KL divergence analytically or using Monte Carlo estimation.
        z: torch.Tensor
            The sampled latent tensor.
        """
        kl_dict = {
            "kl_elementwise": None,  # (batch, ch, h, w)
            "kl_samplewise": None,  # (batch, )
            "kl_samplewise_restricted": None,  # (batch, )
            "kl_spatial": None,  # (batch, h, w)
            "kl_channelwise": None,  # (batch, ch)
        }
        return kl_dict

    def process_p_params(self, p_params, var_clip_max):
        """
        Map `p_params` to the prior mean.

        The input is (optionally) convolved to `2 * c_vars` channels and split in two
        halves; only the first half (the mean) is kept. `var_clip_max` is unused.

        Returns
        -------
        tuple
            `(p_mu, None)`, where `p_mu` is a plain tensor.
        """
        if self.transform_p_params:
            p_params = self.conv_in_p(p_params)
        else:
            assert (
                p_params.size(1) == 2 * self.c_vars
            ), f"{p_params.shape} {self.c_vars}"

        # Define p(z): only the mean half is used by this non-stochastic block.
        p_mu, _ = p_params.chunk(2, dim=1)
        return p_mu, None

    def process_q_params(self, q_params, var_clip_max, allow_oddsizes=False):
        """
        Map `q_params` to the inference mean.

        The input is convolved to `2 * c_vars` channels and split in two halves; only the
        first half (the mean) is kept. If its last spatial dimension is odd and
        `allow_oddsizes` is `False`, it is center-cropped by one pixel.
        `var_clip_max` is unused.

        Returns
        -------
        tuple
            `(q_mu, None)`, where `q_mu` is a plain tensor.
        """
        # Define q(z)
        q_params = self.conv_in_q(q_params)
        q_mu, _ = q_params.chunk(2, dim=1)

        if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
            q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)

        return q_mu, None

    def forward(
        self,
        p_params: torch.Tensor,
        q_params: torch.Tensor = None,
        forced_latent: Union[None, torch.Tensor] = None,
        use_mode: bool = False,
        force_constant_output: bool = False,
        analytical_kl: bool = False,
        mode_pred: bool = False,
        use_uncond_mode: bool = False,
        var_clip_max: float = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the deterministic "latent" and map it through the output convolution.

        `forced_latent`, `use_mode`, `analytical_kl`, `use_uncond_mode` and
        `var_clip_max` are accepted for interface compatibility with
        `NormalStochasticBlock2d` but do not affect the result.

        Parameters
        ----------
        p_params: torch.Tensor
            The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
        q_params: torch.Tensor, optional
            The tensor resulting from merging the bu_value tensor at the same hierarchical level
            from the bottom-up pass and the `p_params` tensor. Default is `None`.
        forced_latent: torch.Tensor, optional
            Unused (must not be combined with `use_mode=True`). Default is `None`.
        use_mode: bool, optional
            Unused. Default is `False`.
        force_constant_output: bool, optional
            Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
            This is used when doing experiment from the prior - q is not used.
            Default is `False`.
        analytical_kl: bool, optional
            Unused. Default is `False`.
        mode_pred: bool, optional
            Whether the model is in prediction mode. Default is `False`.
        use_uncond_mode: bool, optional
            Unused. Default is `False`.
        var_clip_max: float, optional
            Unused. Default is `None`.

        Returns
        -------
        tuple
            The output of `conv_out` applied to z, and a dictionary with `z`,
            `p_params`, `q_params`, `logprob_q` (always `None`) and `qvar_max`.
        """
        debug_qvar_max = 0
        assert (forced_latent is None) or (not use_mode)

        p_mu, _ = self.process_p_params(p_params, var_clip_max)

        if q_params is not None:
            # At inference time, just don't centercrop the q_params even if they are odd in size.
            q_mu, _ = self.process_q_params(
                q_params, var_clip_max, allow_oddsizes=mode_pred is True
            )
            q_params = (q_mu, None)
            debug_qvar_max = torch.Tensor([1]).to(q_mu.device)
            # "Sample" from q(z): the mean itself.
            sampling_distrib = q_mu
            # Centercrop p_mu so that its size matches the one of q_mu.
            # p_mu is a plain tensor here (no `centercrop_to_size` method),
            # so it is cropped functionally.
            q_size = q_mu.shape[-1]
            if p_mu.shape[-1] != q_size and mode_pred is False:
                p_mu = F.center_crop(p_mu, q_size)
        else:
            # "Sample" from p(z): the mean itself.
            sampling_distrib = p_mu

        # Built after the optional crop so it reflects the tensor actually used.
        p_params = (p_mu, None)

        # Generate latent variable (deterministic here)
        z = sampling_distrib

        # Copy one sample (and distrib parameters) over the whole batch.
        # This is used when doing experiment from the prior - q is not used.
        # The log-variance slot is always None in this block, so it is left as is.
        if force_constant_output:
            z = z[0:1].expand_as(z).clone()
            p_params = tuple(
                None if param is None else param[0:1].expand_as(param).clone()
                for param in p_params
            )

        # Output of stochastic layer
        out = self.conv_out(z)

        kl_dict = {}
        logprob_q = None

        data = kl_dict
        data["z"] = z  # sampled variable at this layer (batch, ch, h, w)
        data["p_params"] = p_params  # (b, ch, h, w) where b is 1 or batch size
        data["q_params"] = q_params  # (batch, ch, h, w)
        data["logprob_q"] = logprob_q  # (batch, )
        data["qvar_max"] = debug_qvar_max

        return out, data
|