careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/layers.py
CHANGED
|
@@ -1,27 +1,23 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Script containing the common basic blocks (nn.Module) reused by the LadderVAE architecture.
|
|
3
|
-
|
|
4
|
-
Hierarchy in the model blocks:
|
|
5
|
-
|
|
6
|
-
"""
|
|
1
|
+
"""Script containing the common basic blocks (nn.Module) reused by the LadderVAE."""
|
|
7
2
|
|
|
3
|
+
from collections.abc import Iterable
|
|
8
4
|
from copy import deepcopy
|
|
9
|
-
from typing import Callable,
|
|
5
|
+
from typing import Callable, Literal, Optional, Union
|
|
10
6
|
|
|
7
|
+
import numpy as np
|
|
11
8
|
import torch
|
|
12
9
|
import torch.nn as nn
|
|
13
|
-
import torchvision.transforms.functional as F
|
|
14
|
-
from torch.distributions import kl_divergence
|
|
15
|
-
from torch.distributions.normal import Normal
|
|
16
10
|
|
|
11
|
+
from .stochastic import NormalStochasticBlock
|
|
17
12
|
from .utils import (
|
|
18
|
-
StableLogVar,
|
|
19
|
-
StableMean,
|
|
20
13
|
crop_img_tensor,
|
|
21
|
-
kl_normal_mc,
|
|
22
14
|
pad_img_tensor,
|
|
23
15
|
)
|
|
24
16
|
|
|
17
|
+
ConvType = Union[nn.Conv2d, nn.Conv3d]
|
|
18
|
+
NormType = Union[nn.BatchNorm2d, nn.BatchNorm3d]
|
|
19
|
+
DropoutType = Union[nn.Dropout2d, nn.Dropout3d]
|
|
20
|
+
|
|
25
21
|
|
|
26
22
|
class ResidualBlock(nn.Module):
|
|
27
23
|
"""
|
|
@@ -51,13 +47,13 @@ class ResidualBlock(nn.Module):
|
|
|
51
47
|
self,
|
|
52
48
|
channels: int,
|
|
53
49
|
nonlin: Callable,
|
|
54
|
-
|
|
50
|
+
conv_strides: tuple[int] = (2, 2),
|
|
51
|
+
kernel: Union[int, Iterable[int], None] = None,
|
|
55
52
|
groups: int = 1,
|
|
56
53
|
batchnorm: bool = True,
|
|
57
54
|
block_type: str = None,
|
|
58
55
|
dropout: float = None,
|
|
59
56
|
gated: bool = None,
|
|
60
|
-
skip_padding: bool = False,
|
|
61
57
|
conv2d_bias: bool = True,
|
|
62
58
|
):
|
|
63
59
|
"""
|
|
@@ -85,8 +81,6 @@ class ResidualBlock(nn.Module):
|
|
|
85
81
|
Default is `None`.
|
|
86
82
|
gated: bool, optional
|
|
87
83
|
Whether to use gated layer. Default is `None`.
|
|
88
|
-
skip_padding: bool, optional
|
|
89
|
-
Whether to skip padding in convolutions. Default is `False`.
|
|
90
84
|
conv2d_bias: bool, optional
|
|
91
85
|
Whether to use bias term in convolutions. Default is `True`.
|
|
92
86
|
"""
|
|
@@ -99,99 +93,142 @@ class ResidualBlock(nn.Module):
|
|
|
99
93
|
kernel = (kernel, kernel)
|
|
100
94
|
elif len(kernel) != 2:
|
|
101
95
|
raise ValueError("kernel has to be None, int, or an iterable of length 2")
|
|
102
|
-
assert all(
|
|
96
|
+
assert all(k % 2 == 1 for k in kernel), "kernel sizes have to be odd"
|
|
103
97
|
kernel = list(kernel)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
98
|
+
|
|
99
|
+
# Define modules
|
|
100
|
+
conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
|
|
101
|
+
norm_layer: NormType = getattr(nn, f"BatchNorm{len(conv_strides)}d")
|
|
102
|
+
dropout_layer: DropoutType = getattr(nn, f"Dropout{len(conv_strides)}d")
|
|
103
|
+
# TODO: same comment as in lvae.py, would be more readable to have `conv_dims`
|
|
107
104
|
|
|
108
105
|
modules = []
|
|
109
106
|
if block_type == "cabdcabd":
|
|
110
107
|
for i in range(2):
|
|
111
|
-
conv =
|
|
108
|
+
conv = conv_layer(
|
|
112
109
|
channels,
|
|
113
110
|
channels,
|
|
114
111
|
kernel[i],
|
|
115
|
-
padding=
|
|
112
|
+
padding="same",
|
|
116
113
|
groups=groups,
|
|
117
114
|
bias=conv2d_bias,
|
|
118
115
|
)
|
|
119
116
|
modules.append(conv)
|
|
120
117
|
modules.append(nonlin)
|
|
121
118
|
if batchnorm:
|
|
122
|
-
modules.append(
|
|
119
|
+
modules.append(norm_layer(channels))
|
|
123
120
|
if dropout is not None:
|
|
124
|
-
modules.append(
|
|
121
|
+
modules.append(dropout_layer(dropout))
|
|
125
122
|
elif block_type == "bacdbac":
|
|
126
123
|
for i in range(2):
|
|
127
124
|
if batchnorm:
|
|
128
|
-
modules.append(
|
|
125
|
+
modules.append(norm_layer(channels))
|
|
129
126
|
modules.append(nonlin)
|
|
130
|
-
conv =
|
|
127
|
+
conv = conv_layer(
|
|
131
128
|
channels,
|
|
132
129
|
channels,
|
|
133
130
|
kernel[i],
|
|
134
|
-
padding=
|
|
131
|
+
padding="same",
|
|
135
132
|
groups=groups,
|
|
136
133
|
bias=conv2d_bias,
|
|
137
134
|
)
|
|
138
135
|
modules.append(conv)
|
|
139
136
|
if dropout is not None and i == 0:
|
|
140
|
-
modules.append(
|
|
137
|
+
modules.append(dropout_layer(dropout))
|
|
141
138
|
elif block_type == "bacdbacd":
|
|
142
139
|
for i in range(2):
|
|
143
140
|
if batchnorm:
|
|
144
|
-
modules.append(
|
|
141
|
+
modules.append(norm_layer(channels))
|
|
145
142
|
modules.append(nonlin)
|
|
146
|
-
conv =
|
|
143
|
+
conv = conv_layer(
|
|
147
144
|
channels,
|
|
148
145
|
channels,
|
|
149
146
|
kernel[i],
|
|
150
|
-
padding=
|
|
147
|
+
padding="same",
|
|
151
148
|
groups=groups,
|
|
152
149
|
bias=conv2d_bias,
|
|
153
150
|
)
|
|
154
151
|
modules.append(conv)
|
|
155
|
-
modules.append(
|
|
152
|
+
modules.append(dropout_layer(dropout))
|
|
156
153
|
|
|
157
154
|
else:
|
|
158
155
|
raise ValueError(f"unrecognized block type '{block_type}'")
|
|
159
156
|
|
|
160
157
|
self.gated = gated
|
|
161
158
|
if gated:
|
|
162
|
-
modules.append(
|
|
159
|
+
modules.append(
|
|
160
|
+
GateLayer(
|
|
161
|
+
channels=channels,
|
|
162
|
+
conv_strides=conv_strides,
|
|
163
|
+
kernel_size=1,
|
|
164
|
+
nonlin=nonlin,
|
|
165
|
+
)
|
|
166
|
+
)
|
|
163
167
|
|
|
164
168
|
self.block = nn.Sequential(*modules)
|
|
165
169
|
|
|
166
|
-
def forward(self, x):
|
|
170
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
171
|
+
"""Forward pass.
|
|
172
|
+
|
|
173
|
+
Parameters
|
|
174
|
+
----------
|
|
175
|
+
x : torch.Tensor
|
|
176
|
+
input tensor # TODO add shape
|
|
167
177
|
|
|
178
|
+
Returns
|
|
179
|
+
-------
|
|
180
|
+
torch.Tensor
|
|
181
|
+
output tensor # TODO add shape
|
|
182
|
+
"""
|
|
168
183
|
out = self.block(x)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
184
|
+
assert (
|
|
185
|
+
out.shape == x.shape
|
|
186
|
+
), f"output shape: {out.shape} != input shape: {x.shape}"
|
|
187
|
+
return out + x
|
|
173
188
|
|
|
174
189
|
|
|
175
190
|
class ResidualGatedBlock(ResidualBlock):
|
|
191
|
+
"""Layer class that implements a residual block with a gating mechanism."""
|
|
176
192
|
|
|
177
193
|
def __init__(self, *args, **kwargs):
|
|
178
194
|
super().__init__(*args, **kwargs, gated=True)
|
|
179
195
|
|
|
180
196
|
|
|
181
|
-
class
|
|
197
|
+
class GateLayer(nn.Module):
|
|
182
198
|
"""
|
|
199
|
+
Layer class that implements a gating mechanism.
|
|
200
|
+
|
|
183
201
|
Double the number of channels through a convolutional layer, then use
|
|
184
202
|
half the channels as gate for the other half.
|
|
185
203
|
"""
|
|
186
204
|
|
|
187
|
-
def __init__(
|
|
205
|
+
def __init__(
|
|
206
|
+
self,
|
|
207
|
+
channels: int,
|
|
208
|
+
conv_strides: tuple[int] = (2, 2),
|
|
209
|
+
kernel_size: int = 3,
|
|
210
|
+
nonlin: Callable = nn.LeakyReLU(),
|
|
211
|
+
):
|
|
188
212
|
super().__init__()
|
|
189
213
|
assert kernel_size % 2 == 1
|
|
190
214
|
pad = kernel_size // 2
|
|
191
|
-
|
|
215
|
+
conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
|
|
216
|
+
self.conv = conv_layer(channels, 2 * channels, kernel_size, padding=pad)
|
|
192
217
|
self.nonlin = nonlin
|
|
193
218
|
|
|
194
|
-
def forward(self, x):
|
|
219
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
220
|
+
"""Forward pass.
|
|
221
|
+
|
|
222
|
+
Parameters
|
|
223
|
+
----------
|
|
224
|
+
x : torch.Tensor
|
|
225
|
+
input # TODO add shape
|
|
226
|
+
|
|
227
|
+
Returns
|
|
228
|
+
-------
|
|
229
|
+
torch.Tensor
|
|
230
|
+
output # TODO add shape
|
|
231
|
+
"""
|
|
195
232
|
x = self.conv(x)
|
|
196
233
|
x, gate = torch.chunk(x, 2, dim=1)
|
|
197
234
|
x = self.nonlin(x) # TODO remove this?
|
|
@@ -201,6 +238,8 @@ class GateLayer2d(nn.Module):
|
|
|
201
238
|
|
|
202
239
|
class ResBlockWithResampling(nn.Module):
|
|
203
240
|
"""
|
|
241
|
+
Residual block with resampling.
|
|
242
|
+
|
|
204
243
|
Residual block that takes care of resampling (i.e. downsampling or upsampling) steps (by a factor 2).
|
|
205
244
|
It is structured as follows:
|
|
206
245
|
1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or
|
|
@@ -210,7 +249,7 @@ class ResBlockWithResampling(nn.Module):
|
|
|
210
249
|
|
|
211
250
|
Some implementation notes:
|
|
212
251
|
- Resampling is performed through a strided convolution layer at the beginning of the block.
|
|
213
|
-
- The strided convolution block has fixed kernel size of 3x3 and 1 layer of
|
|
252
|
+
- The strided convolution block has fixed kernel size of 3x3 and 1 layer of padding with zeros.
|
|
214
253
|
- The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers.
|
|
215
254
|
- The number of internal channels is by default the same as the number of output channels, but
|
|
216
255
|
min_inner_channels can override the behaviour.
|
|
@@ -221,16 +260,16 @@ class ResBlockWithResampling(nn.Module):
|
|
|
221
260
|
mode: Literal["top-down", "bottom-up"],
|
|
222
261
|
c_in: int,
|
|
223
262
|
c_out: int,
|
|
224
|
-
|
|
225
|
-
|
|
263
|
+
conv_strides: tuple[int],
|
|
264
|
+
min_inner_channels: Union[int, None] = None,
|
|
265
|
+
nonlin: Callable = nn.LeakyReLU(),
|
|
226
266
|
resample: bool = False,
|
|
227
|
-
res_block_kernel: Union[int, Iterable[int]] = None,
|
|
267
|
+
res_block_kernel: Optional[Union[int, Iterable[int]]] = None,
|
|
228
268
|
groups: int = 1,
|
|
229
269
|
batchnorm: bool = True,
|
|
230
|
-
res_block_type: str = None,
|
|
231
|
-
dropout: float = None,
|
|
232
|
-
gated: bool = None,
|
|
233
|
-
skip_padding: bool = False,
|
|
270
|
+
res_block_type: Union[str, None] = None,
|
|
271
|
+
dropout: Union[float, None] = None,
|
|
272
|
+
gated: Union[bool, None] = None,
|
|
234
273
|
conv2d_bias: bool = True,
|
|
235
274
|
# lowres_input: bool = False,
|
|
236
275
|
):
|
|
@@ -273,14 +312,15 @@ class ResBlockWithResampling(nn.Module):
|
|
|
273
312
|
Default is `None`.
|
|
274
313
|
gated: bool, optional
|
|
275
314
|
Whether to use gated layer. Default is `None`.
|
|
276
|
-
skip_padding: bool, optional
|
|
277
|
-
Whether to skip padding in convolutions. Default is `False`.
|
|
278
315
|
conv2d_bias: bool, optional
|
|
279
316
|
Whether to use bias term in convolutions. Default is `True`.
|
|
280
317
|
"""
|
|
281
318
|
super().__init__()
|
|
282
319
|
assert mode in ["top-down", "bottom-up"]
|
|
283
320
|
|
|
321
|
+
conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
|
|
322
|
+
transp_conv_layer: ConvType = getattr(nn, f"ConvTranspose{len(conv_strides)}d")
|
|
323
|
+
|
|
284
324
|
if min_inner_channels is None:
|
|
285
325
|
min_inner_channels = 0
|
|
286
326
|
# inner_channels is the number of channels used in the inner layers
|
|
@@ -290,28 +330,28 @@ class ResBlockWithResampling(nn.Module):
|
|
|
290
330
|
# Define first conv layer to change num channels and/or up/downsample
|
|
291
331
|
if resample:
|
|
292
332
|
if mode == "bottom-up": # downsample
|
|
293
|
-
self.pre_conv =
|
|
333
|
+
self.pre_conv = conv_layer(
|
|
294
334
|
in_channels=c_in,
|
|
295
335
|
out_channels=inner_channels,
|
|
296
336
|
kernel_size=3,
|
|
297
337
|
padding=1,
|
|
298
|
-
stride=
|
|
338
|
+
stride=conv_strides,
|
|
299
339
|
groups=groups,
|
|
300
340
|
bias=conv2d_bias,
|
|
301
341
|
)
|
|
302
342
|
elif mode == "top-down": # upsample
|
|
303
|
-
self.pre_conv =
|
|
343
|
+
self.pre_conv = transp_conv_layer(
|
|
304
344
|
in_channels=c_in,
|
|
305
345
|
kernel_size=3,
|
|
306
346
|
out_channels=inner_channels,
|
|
307
|
-
padding=1,
|
|
308
|
-
stride=
|
|
347
|
+
padding=1, # TODO maybe don't hardcode this?
|
|
348
|
+
stride=conv_strides,
|
|
309
349
|
groups=groups,
|
|
310
|
-
output_padding=1,
|
|
350
|
+
output_padding=1 if len(conv_strides) == 2 else (0, 1, 1),
|
|
311
351
|
bias=conv2d_bias,
|
|
312
352
|
)
|
|
313
353
|
elif c_in != inner_channels:
|
|
314
|
-
self.pre_conv =
|
|
354
|
+
self.pre_conv = conv_layer(
|
|
315
355
|
c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
|
|
316
356
|
)
|
|
317
357
|
else:
|
|
@@ -320,6 +360,7 @@ class ResBlockWithResampling(nn.Module):
|
|
|
320
360
|
# Residual block
|
|
321
361
|
self.res = ResidualBlock(
|
|
322
362
|
channels=inner_channels,
|
|
363
|
+
conv_strides=conv_strides,
|
|
323
364
|
nonlin=nonlin,
|
|
324
365
|
kernel=res_block_kernel,
|
|
325
366
|
groups=groups,
|
|
@@ -327,19 +368,30 @@ class ResBlockWithResampling(nn.Module):
|
|
|
327
368
|
dropout=dropout,
|
|
328
369
|
gated=gated,
|
|
329
370
|
block_type=res_block_type,
|
|
330
|
-
skip_padding=skip_padding,
|
|
331
371
|
conv2d_bias=conv2d_bias,
|
|
332
372
|
)
|
|
333
373
|
|
|
334
374
|
# Define last conv layer to get correct num output channels
|
|
335
375
|
if inner_channels != c_out:
|
|
336
|
-
self.post_conv =
|
|
376
|
+
self.post_conv = conv_layer(
|
|
337
377
|
inner_channels, c_out, 1, groups=groups, bias=conv2d_bias
|
|
338
378
|
)
|
|
339
379
|
else:
|
|
340
380
|
self.post_conv = None
|
|
341
381
|
|
|
342
|
-
def forward(self, x):
|
|
382
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
383
|
+
"""Forward pass.
|
|
384
|
+
|
|
385
|
+
Parameters
|
|
386
|
+
----------
|
|
387
|
+
x : torch.Tensor
|
|
388
|
+
input # TODO add shape
|
|
389
|
+
|
|
390
|
+
Returns
|
|
391
|
+
-------
|
|
392
|
+
torch.Tensor
|
|
393
|
+
output # TODO add shape
|
|
394
|
+
"""
|
|
343
395
|
if self.pre_conv is not None:
|
|
344
396
|
x = self.pre_conv(x)
|
|
345
397
|
|
|
@@ -351,6 +403,7 @@ class ResBlockWithResampling(nn.Module):
|
|
|
351
403
|
|
|
352
404
|
|
|
353
405
|
class TopDownDeterministicResBlock(ResBlockWithResampling):
|
|
406
|
+
"""Resnet block for top-down deterministic layers."""
|
|
354
407
|
|
|
355
408
|
def __init__(self, *args, upsample: bool = False, **kwargs):
|
|
356
409
|
kwargs["resample"] = upsample
|
|
@@ -358,6 +411,7 @@ class TopDownDeterministicResBlock(ResBlockWithResampling):
|
|
|
358
411
|
|
|
359
412
|
|
|
360
413
|
class BottomUpDeterministicResBlock(ResBlockWithResampling):
|
|
414
|
+
"""Resnet block for bottom-up deterministic layers."""
|
|
361
415
|
|
|
362
416
|
def __init__(self, *args, downsample: bool = False, **kwargs):
|
|
363
417
|
kwargs["resample"] = downsample
|
|
@@ -367,6 +421,7 @@ class BottomUpDeterministicResBlock(ResBlockWithResampling):
|
|
|
367
421
|
class BottomUpLayer(nn.Module):
|
|
368
422
|
"""
|
|
369
423
|
Bottom-up deterministic layer.
|
|
424
|
+
|
|
370
425
|
It consists of one or a stack of `BottomUpDeterministicResBlock`'s.
|
|
371
426
|
The outputs are the so-called `bu_values` that are later used in the Decoder to update the
|
|
372
427
|
generative distributions.
|
|
@@ -385,20 +440,20 @@ class BottomUpLayer(nn.Module):
|
|
|
385
440
|
self,
|
|
386
441
|
n_res_blocks: int,
|
|
387
442
|
n_filters: int,
|
|
443
|
+
conv_strides: tuple[int] = (2, 2),
|
|
388
444
|
downsampling_steps: int = 0,
|
|
389
|
-
nonlin: Callable = None,
|
|
445
|
+
nonlin: Optional[Callable] = None,
|
|
390
446
|
batchnorm: bool = True,
|
|
391
|
-
dropout: float = None,
|
|
392
|
-
res_block_type: str = None,
|
|
393
|
-
res_block_kernel: int = None,
|
|
394
|
-
|
|
395
|
-
gated: bool = None,
|
|
447
|
+
dropout: Optional[float] = None,
|
|
448
|
+
res_block_type: Optional[str] = None,
|
|
449
|
+
res_block_kernel: Optional[int] = None,
|
|
450
|
+
gated: Optional[bool] = None,
|
|
396
451
|
enable_multiscale: bool = False,
|
|
397
|
-
multiscale_lowres_size_factor: int = None,
|
|
452
|
+
multiscale_lowres_size_factor: Optional[int] = None,
|
|
398
453
|
lowres_separate_branch: bool = False,
|
|
399
454
|
multiscale_retain_spatial_dims: bool = False,
|
|
400
455
|
decoder_retain_spatial_dims: bool = False,
|
|
401
|
-
output_expected_shape: Iterable[int] = None,
|
|
456
|
+
output_expected_shape: Optional[Iterable[int]] = None,
|
|
402
457
|
):
|
|
403
458
|
"""
|
|
404
459
|
Constructor.
|
|
@@ -427,8 +482,6 @@ class BottomUpLayer(nn.Module):
|
|
|
427
482
|
The kernel size used in the convolutions of the residual block.
|
|
428
483
|
It can be either a single integer or a pair of integers defining the squared kernel.
|
|
429
484
|
Default is `None`.
|
|
430
|
-
res_block_skip_padding: bool, optional
|
|
431
|
-
Whether to skip padding in convolutions in the Residual block. Default is `False`.
|
|
432
485
|
gated: bool, optional
|
|
433
486
|
Whether to use gated layer. Default is `None`.
|
|
434
487
|
enable_multiscale: bool, optional
|
|
@@ -443,7 +496,8 @@ class BottomUpLayer(nn.Module):
|
|
|
443
496
|
Whether to pad the latent tensor resulting from the bottom-up layer's primary flow
|
|
444
497
|
to match the size of the low-res input. Default is `False`.
|
|
445
498
|
decoder_retain_spatial_dims: bool, optional
|
|
446
|
-
|
|
499
|
+
Whether in the corresponding top-down layer the shape of tensor is retained between
|
|
500
|
+
input and output. Default is `False`.
|
|
447
501
|
output_expected_shape: Iterable[int], optional
|
|
448
502
|
The expected shape of the layer output (only used if `enable_multiscale == True`).
|
|
449
503
|
Default is `None`.
|
|
@@ -467,6 +521,7 @@ class BottomUpLayer(nn.Module):
|
|
|
467
521
|
do_resample = True
|
|
468
522
|
downsampling_steps -= 1
|
|
469
523
|
block = BottomUpDeterministicResBlock(
|
|
524
|
+
conv_strides=conv_strides,
|
|
470
525
|
c_in=n_filters,
|
|
471
526
|
c_out=n_filters,
|
|
472
527
|
nonlin=nonlin,
|
|
@@ -475,7 +530,6 @@ class BottomUpLayer(nn.Module):
|
|
|
475
530
|
dropout=dropout,
|
|
476
531
|
res_block_type=res_block_type,
|
|
477
532
|
res_block_kernel=res_block_kernel,
|
|
478
|
-
skip_padding=res_block_skip_padding,
|
|
479
533
|
gated=gated,
|
|
480
534
|
)
|
|
481
535
|
if do_resample:
|
|
@@ -491,6 +545,7 @@ class BottomUpLayer(nn.Module):
|
|
|
491
545
|
if self.enable_multiscale:
|
|
492
546
|
self._init_multiscale(
|
|
493
547
|
n_filters=n_filters,
|
|
548
|
+
conv_strides=conv_strides,
|
|
494
549
|
nonlin=nonlin,
|
|
495
550
|
batchnorm=batchnorm,
|
|
496
551
|
dropout=dropout,
|
|
@@ -506,20 +561,25 @@ class BottomUpLayer(nn.Module):
|
|
|
506
561
|
self,
|
|
507
562
|
nonlin: Callable = None,
|
|
508
563
|
n_filters: int = None,
|
|
564
|
+
conv_strides: tuple[int] = (2, 2),
|
|
509
565
|
batchnorm: bool = None,
|
|
510
566
|
dropout: float = None,
|
|
511
567
|
res_block_type: str = None,
|
|
512
568
|
) -> None:
|
|
513
569
|
"""
|
|
514
|
-
|
|
515
|
-
of the primary flow at different hierarchical levels in the multiresolution approach (LC).
|
|
570
|
+
Bottom-up layer's method that initializes the LC modules.
|
|
516
571
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
572
|
+
Defines the modules responsible for merging compressed lateral inputs to the
|
|
573
|
+
outputs of the primary flow at different hierarchical levels in the
|
|
574
|
+
multiresolution approach (LC). Specifically, the method initializes `lowres_net`
|
|
575
|
+
, which is a stack of `BottomUpDeterministicResBlock`'s (w/out downsampling) that
|
|
576
|
+
takes care of additionally processing the low-res input, and `lowres_merge`,
|
|
577
|
+
which is the module responsible for merging the compressed lateral input to the
|
|
578
|
+
main flow.
|
|
520
579
|
|
|
521
|
-
NOTE: The merge modality is set by default to "residual", meaning that the
|
|
522
|
-
performs concatenation on dim=1, followed by 1x1 convolution and
|
|
580
|
+
NOTE: The merge modality is set by default to "residual", meaning that the
|
|
581
|
+
merge layer performs concatenation on dim=1, followed by 1x1 convolution and
|
|
582
|
+
a Residual Gated block.
|
|
523
583
|
|
|
524
584
|
Parameters
|
|
525
585
|
----------
|
|
@@ -543,6 +603,7 @@ class BottomUpLayer(nn.Module):
|
|
|
543
603
|
|
|
544
604
|
self.lowres_merge = MergeLowRes(
|
|
545
605
|
channels=n_filters,
|
|
606
|
+
conv_strides=conv_strides,
|
|
546
607
|
merge_type="residual",
|
|
547
608
|
nonlin=nonlin,
|
|
548
609
|
batchnorm=batchnorm,
|
|
@@ -553,9 +614,10 @@ class BottomUpLayer(nn.Module):
|
|
|
553
614
|
)
|
|
554
615
|
|
|
555
616
|
def forward(
|
|
556
|
-
self, x: torch.Tensor, lowres_x: torch.Tensor = None
|
|
557
|
-
) ->
|
|
558
|
-
"""
|
|
617
|
+
self, x: torch.Tensor, lowres_x: Union[torch.Tensor, None] = None
|
|
618
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
619
|
+
"""Forward pass.
|
|
620
|
+
|
|
559
621
|
Parameters
|
|
560
622
|
----------
|
|
561
623
|
x: torch.Tensor
|
|
@@ -563,6 +625,9 @@ class BottomUpLayer(nn.Module):
|
|
|
563
625
|
previous layer.
|
|
564
626
|
lowres_x: torch.Tensor, optional
|
|
565
627
|
The low-res input used for Lateral Contextualization (LC). Default is `None`.
|
|
628
|
+
|
|
629
|
+
NOTE: first returned tensor is used as input for the next BU layer, while the second
|
|
630
|
+
tensor is the bu_value passed to the top-down layer.
|
|
566
631
|
"""
|
|
567
632
|
# The input is fed through the residual downsampling block(s)
|
|
568
633
|
primary_flow = self.net_downsized(x)
|
|
@@ -582,12 +647,25 @@ class BottomUpLayer(nn.Module):
|
|
|
582
647
|
else:
|
|
583
648
|
merged = primary_flow
|
|
584
649
|
|
|
650
|
+
# NOTE: Explanation of possible cases for the conditionals:
|
|
651
|
+
# - if both are `True` -> `merged` has the same spatial dims as the input (`x`) since
|
|
652
|
+
# spatial dims are retained by padding `primary_flow` in `MergeLowRes`. This is
|
|
653
|
+
# OK for the corresponding TopDown layer, as it also retains spatial dims.
|
|
654
|
+
# - if both are `False` -> `merged`'s spatial dims are equal to `self.net_downsized(x)`,
|
|
655
|
+
# since no padding is done in `MergeLowRes` and, instead, the lowres input is cropped.
|
|
656
|
+
# This is OK for the corresponding TopDown layer, as it also halves the spatial dims.
|
|
657
|
+
# - if 1st is `False` and 2nd is `True` -> not a concern, it cannot happen
|
|
658
|
+
# (see lvae.py, line 111, initialization of `multiscale_decoder_retain_spatial_dims`).
|
|
585
659
|
if (
|
|
586
660
|
self.multiscale_retain_spatial_dims is False
|
|
587
661
|
or self.decoder_retain_spatial_dims is True
|
|
588
662
|
):
|
|
589
663
|
return merged, merged
|
|
590
664
|
|
|
665
|
+
# NOTE: if we reach here, it means that `multiscale_retain_spatial_dims` is `True`,
|
|
666
|
+
# but `decoder_retain_spatial_dims` is `False`, meaning that merging LC preserves
|
|
667
|
+
# the spatial dimensions, but at the same time we don't want to retain the spatial
|
|
668
|
+
# dims in the corresponding top-down layer. Therefore, we need to crop the tensor.
|
|
591
669
|
if self.output_expected_shape is not None:
|
|
592
670
|
expected_shape = self.output_expected_shape
|
|
593
671
|
else:
|
|
@@ -602,7 +680,10 @@ class BottomUpLayer(nn.Module):
|
|
|
602
680
|
|
|
603
681
|
class MergeLayer(nn.Module):
|
|
604
682
|
"""
|
|
605
|
-
|
|
683
|
+
Layer class that merges two or more input tensors.
|
|
684
|
+
|
|
685
|
+
Merges two or more (B, C, [Z], Y, X) input tensors by concatenating
|
|
686
|
+
them along dim=1 and passes the result through:
|
|
606
687
|
a) a convolutional 1x1 layer (`merge_type == "linear"`), or
|
|
607
688
|
b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
|
|
608
689
|
c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
|
|
@@ -612,13 +693,13 @@ class MergeLayer(nn.Module):
|
|
|
612
693
|
self,
|
|
613
694
|
merge_type: Literal["linear", "residual", "residual_ungated"],
|
|
614
695
|
channels: Union[int, Iterable[int]],
|
|
615
|
-
|
|
696
|
+
conv_strides: tuple[int] = (2, 2),
|
|
697
|
+
nonlin: Callable = nn.LeakyReLU(),
|
|
616
698
|
batchnorm: bool = True,
|
|
617
|
-
dropout: float = None,
|
|
618
|
-
res_block_type: str = None,
|
|
619
|
-
res_block_kernel: int = None,
|
|
620
|
-
|
|
621
|
-
conv2d_bias: bool = True,
|
|
699
|
+
dropout: Optional[float] = None,
|
|
700
|
+
res_block_type: Optional[str] = None,
|
|
701
|
+
res_block_kernel: Optional[int] = None,
|
|
702
|
+
conv2d_bias: Optional[bool] = True,
|
|
622
703
|
):
|
|
623
704
|
"""
|
|
624
705
|
Constructor.
|
|
@@ -626,16 +707,21 @@ class MergeLayer(nn.Module):
|
|
|
626
707
|
Parameters
|
|
627
708
|
----------
|
|
628
709
|
merge_type: Literal["linear", "residual", "residual_ungated"]
|
|
629
|
-
The type of merge done in the layer. It can be chosen between "linear",
|
|
630
|
-
Check the class docstring for more
|
|
710
|
+
The type of merge done in the layer. It can be chosen between "linear",
|
|
711
|
+
"residual", and "residual_ungated". Check the class docstring for more
|
|
712
|
+
information about the behaviour of different merge modalities.
|
|
631
713
|
channels: Union[int, Iterable[int]]
|
|
632
714
|
The number of channels used in the convolutional blocks of this layer.
|
|
633
715
|
If it is an `int`:
|
|
634
716
|
- 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
|
|
635
717
|
- (Optional) ResBlock: in_channels=channels, out_channels=channels
|
|
636
718
|
If it is an Iterable (must have `len(channels)==3`):
|
|
637
|
-
- 1st 1x1 Conv2d: in_channels=sum(channels[:-1]),
|
|
638
|
-
|
|
719
|
+
- 1st 1x1 Conv2d: in_channels=sum(channels[:-1]),
|
|
720
|
+
out_channels=channels[-1]
|
|
721
|
+
- (Optional) ResBlock: in_channels=channels[-1],
|
|
722
|
+
out_channels=channels[-1]
|
|
723
|
+
conv_strides: tuple, optional
|
|
724
|
+
The strides used in the convolutions. Default is `(2, 2)`.
|
|
639
725
|
nonlin: Callable, optional
|
|
640
726
|
The non-linearity function used in the block. Default is `nn.LeakyReLU`.
|
|
641
727
|
batchnorm: bool, optional
|
|
@@ -649,10 +735,9 @@ class MergeLayer(nn.Module):
|
|
|
649
735
|
Default is `None`.
|
|
650
736
|
res_block_kernel: Union[int, Iterable[int]], optional
|
|
651
737
|
The kernel size used in the convolutions of the residual block.
|
|
652
|
-
It can be either a single integer or a pair of integers defining the squared
|
|
738
|
+
It can be either a single integer or a pair of integers defining the squared
|
|
739
|
+
kernel.
|
|
653
740
|
Default is `None`.
|
|
654
|
-
res_block_skip_padding: bool, optional
|
|
655
|
-
Whether to skip padding in convolutions in the Residual block. Default is `False`.
|
|
656
741
|
conv2d_bias: bool, optional
|
|
657
742
|
Whether to use bias term in convolutions. Default is `True`.
|
|
658
743
|
"""
|
|
@@ -665,42 +750,42 @@ class MergeLayer(nn.Module):
|
|
|
665
750
|
if len(channels) == 1:
|
|
666
751
|
channels = [channels[0]] * 3
|
|
667
752
|
|
|
668
|
-
|
|
753
|
+
self.conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
|
|
669
754
|
|
|
670
755
|
if merge_type == "linear":
|
|
671
|
-
self.layer =
|
|
756
|
+
self.layer = self.conv_layer(
|
|
672
757
|
sum(channels[:-1]), channels[-1], 1, bias=conv2d_bias
|
|
673
758
|
)
|
|
674
759
|
elif merge_type == "residual":
|
|
675
760
|
self.layer = nn.Sequential(
|
|
676
|
-
|
|
761
|
+
self.conv_layer(
|
|
677
762
|
sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
|
|
678
763
|
),
|
|
679
764
|
ResidualGatedBlock(
|
|
680
|
-
|
|
681
|
-
|
|
765
|
+
conv_strides=conv_strides,
|
|
766
|
+
channels=channels[-1],
|
|
767
|
+
nonlin=nonlin,
|
|
682
768
|
batchnorm=batchnorm,
|
|
683
769
|
dropout=dropout,
|
|
684
770
|
block_type=res_block_type,
|
|
685
771
|
kernel=res_block_kernel,
|
|
686
772
|
conv2d_bias=conv2d_bias,
|
|
687
|
-
skip_padding=res_block_skip_padding,
|
|
688
773
|
),
|
|
689
774
|
)
|
|
690
775
|
elif merge_type == "residual_ungated":
|
|
691
776
|
self.layer = nn.Sequential(
|
|
692
|
-
|
|
777
|
+
self.conv_layer(
|
|
693
778
|
sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
|
|
694
779
|
),
|
|
695
780
|
ResidualBlock(
|
|
696
|
-
|
|
697
|
-
|
|
781
|
+
conv_strides=conv_strides,
|
|
782
|
+
channels=channels[-1],
|
|
783
|
+
nonlin=nonlin,
|
|
698
784
|
batchnorm=batchnorm,
|
|
699
785
|
dropout=dropout,
|
|
700
786
|
block_type=res_block_type,
|
|
701
787
|
kernel=res_block_kernel,
|
|
702
788
|
conv2d_bias=conv2d_bias,
|
|
703
|
-
skip_padding=res_block_skip_padding,
|
|
704
789
|
),
|
|
705
790
|
)
|
|
706
791
|
|
|
@@ -717,7 +802,9 @@ class MergeLayer(nn.Module):
|
|
|
717
802
|
|
|
718
803
|
class MergeLowRes(MergeLayer):
|
|
719
804
|
"""
|
|
720
|
-
Child class of `MergeLayer
|
|
805
|
+
Child class of `MergeLayer`.
|
|
806
|
+
|
|
807
|
+
Specifically designed to merge the low-resolution patches
|
|
721
808
|
that are used in Lateral Contextualization approach.
|
|
722
809
|
"""
|
|
723
810
|
|
|
@@ -727,7 +814,8 @@ class MergeLowRes(MergeLayer):
|
|
|
727
814
|
super().__init__(*args, **kwargs)
|
|
728
815
|
|
|
729
816
|
def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
|
|
730
|
-
"""
|
|
817
|
+
"""Forward pass.
|
|
818
|
+
|
|
731
819
|
Parameters
|
|
732
820
|
----------
|
|
733
821
|
latent: torch.Tensor
|
|
@@ -735,25 +823,28 @@ class MergeLowRes(MergeLayer):
|
|
|
735
823
|
lowres: torch.Tensor
|
|
736
824
|
The low-res patch image to be merged to increase the context.
|
|
737
825
|
"""
|
|
826
|
+
# TODO: treat (X, Y) and Z differently (e.g., line 762)
|
|
738
827
|
if self.retain_spatial_dims:
|
|
739
828
|
# Pad latent tensor to match lowres tensor's shape
|
|
829
|
+
# Output.shape == Lowres.shape (== Input.shape),
|
|
830
|
+
# where Input is the input to the BU layer
|
|
740
831
|
latent = pad_img_tensor(latent, lowres.shape[2:])
|
|
741
832
|
else:
|
|
742
833
|
# Crop lowres tensor to match latent tensor's shape
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
834
|
+
lz, ly, lx = lowres.shape[2:]
|
|
835
|
+
z = lz // self.multiscale_lowres_size_factor
|
|
836
|
+
y = ly // self.multiscale_lowres_size_factor
|
|
837
|
+
x = lx // self.multiscale_lowres_size_factor
|
|
838
|
+
z_pad = (lz - z) // 2
|
|
839
|
+
y_pad = (ly - y) // 2
|
|
840
|
+
x_pad = (lx - x) // 2
|
|
841
|
+
lowres = lowres[:, :, z_pad:-z_pad, y_pad:-y_pad, x_pad:-x_pad]
|
|
749
842
|
|
|
750
843
|
return super().forward(latent, lowres)
|
|
751
844
|
|
|
752
845
|
|
|
753
846
|
class SkipConnectionMerger(MergeLayer):
|
|
754
|
-
"""
|
|
755
|
-
A specialized `MergeLayer` module, designed to handle skip connections in the model.
|
|
756
|
-
"""
|
|
847
|
+
"""Specialized `MergeLayer` module, handles skip connections in the model."""
|
|
757
848
|
|
|
758
849
|
def __init__(
|
|
759
850
|
self,
|
|
@@ -762,10 +853,10 @@ class SkipConnectionMerger(MergeLayer):
|
|
|
762
853
|
batchnorm: bool,
|
|
763
854
|
dropout: float,
|
|
764
855
|
res_block_type: str,
|
|
856
|
+
conv_strides: tuple[int] = (2, 2),
|
|
765
857
|
merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
|
|
766
858
|
conv2d_bias: bool = True,
|
|
767
|
-
res_block_kernel: int = None,
|
|
768
|
-
res_block_skip_padding: bool = False,
|
|
859
|
+
res_block_kernel: Optional[int] = None,
|
|
769
860
|
):
|
|
770
861
|
"""
|
|
771
862
|
Constructor.
|
|
@@ -780,15 +871,15 @@ class SkipConnectionMerger(MergeLayer):
|
|
|
780
871
|
If it is an Iterable (must have `len(channels)==3`):
|
|
781
872
|
- 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
|
|
782
873
|
- (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
|
|
783
|
-
batchnorm: bool
|
|
784
|
-
Whether to use batchnorm layers.
|
|
785
|
-
dropout: float
|
|
874
|
+
batchnorm: bool
|
|
875
|
+
Whether to use batchnorm layers.
|
|
876
|
+
dropout: float
|
|
786
877
|
The dropout probability in dropout layers. If `None` dropout is not used.
|
|
787
|
-
|
|
788
|
-
res_block_type: str, optional
|
|
878
|
+
res_block_type: str
|
|
789
879
|
A string specifying the structure of residual block.
|
|
790
880
|
Check `ResidualBlock` doscstring for more information.
|
|
791
|
-
|
|
881
|
+
conv_strides: tuple, optional
|
|
882
|
+
The strides used in the convolutions. Default is `(2, 2)`.
|
|
792
883
|
merge_type: Literal["linear", "residual", "residual_ungated"]
|
|
793
884
|
The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
|
|
794
885
|
Check the class docstring for more information about the behaviour of different merge modalities.
|
|
@@ -798,10 +889,9 @@ class SkipConnectionMerger(MergeLayer):
|
|
|
798
889
|
The kernel size used in the convolutions of the residual block.
|
|
799
890
|
It can be either a single integer or a pair of integers defining the squared kernel.
|
|
800
891
|
Default is `None`.
|
|
801
|
-
res_block_skip_padding: bool, optional
|
|
802
|
-
Whether to skip padding in convolutions in the Residual block. Default is `False`.
|
|
803
892
|
"""
|
|
804
893
|
super().__init__(
|
|
894
|
+
conv_strides=conv_strides,
|
|
805
895
|
channels=channels,
|
|
806
896
|
nonlin=nonlin,
|
|
807
897
|
merge_type=merge_type,
|
|
@@ -810,26 +900,25 @@ class SkipConnectionMerger(MergeLayer):
|
|
|
810
900
|
res_block_type=res_block_type,
|
|
811
901
|
res_block_kernel=res_block_kernel,
|
|
812
902
|
conv2d_bias=conv2d_bias,
|
|
813
|
-
res_block_skip_padding=res_block_skip_padding,
|
|
814
903
|
)
|
|
815
904
|
|
|
816
905
|
|
|
817
906
|
class TopDownLayer(nn.Module):
|
|
818
|
-
"""
|
|
819
|
-
|
|
907
|
+
"""Top-down inference layer.
|
|
908
|
+
|
|
820
909
|
It includes:
|
|
821
910
|
- Stochastic sampling,
|
|
822
911
|
- Computation of KL divergence,
|
|
823
912
|
- A small deterministic ResNet that performs upsampling.
|
|
824
913
|
|
|
825
914
|
NOTE 1:
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
915
|
+
The algorithm for generative inference approximately works as follows:
|
|
916
|
+
- p_params = output of top-down layer above
|
|
917
|
+
- bu = inferred bottom-up value at this layer
|
|
918
|
+
- q_params = merge(bu, p_params)
|
|
919
|
+
- z = stochastic_layer(q_params)
|
|
920
|
+
- (optional) get and merge skip connection from prev top-down layer
|
|
921
|
+
- top-down deterministic ResNet
|
|
833
922
|
|
|
834
923
|
NOTE 2:
|
|
835
924
|
The Top-Down layer can work in two modes: inference and prediction/generative.
|
|
@@ -856,28 +945,26 @@ class TopDownLayer(nn.Module):
|
|
|
856
945
|
z_dim: int,
|
|
857
946
|
n_res_blocks: int,
|
|
858
947
|
n_filters: int,
|
|
948
|
+
conv_strides: tuple[int],
|
|
859
949
|
is_top_layer: bool = False,
|
|
860
|
-
|
|
861
|
-
nonlin: Callable = None,
|
|
862
|
-
merge_type:
|
|
950
|
+
upsampling_steps: Union[int, None] = None,
|
|
951
|
+
nonlin: Union[Callable, None] = None,
|
|
952
|
+
merge_type: Union[
|
|
953
|
+
Literal["linear", "residual", "residual_ungated"], None
|
|
954
|
+
] = None,
|
|
863
955
|
batchnorm: bool = True,
|
|
864
|
-
dropout: float = None,
|
|
956
|
+
dropout: Union[float, None] = None,
|
|
865
957
|
stochastic_skip: bool = False,
|
|
866
|
-
res_block_type: str = None,
|
|
867
|
-
res_block_kernel: int = None,
|
|
868
|
-
res_block_skip_padding: bool = None,
|
|
958
|
+
res_block_type: Union[str, None] = None,
|
|
959
|
+
res_block_kernel: Union[int, None] = None,
|
|
869
960
|
groups: int = 1,
|
|
870
|
-
gated: bool = None,
|
|
961
|
+
gated: Union[bool, None] = None,
|
|
871
962
|
learn_top_prior: bool = False,
|
|
872
|
-
top_prior_param_shape: Iterable[int] = None,
|
|
963
|
+
top_prior_param_shape: Union[Iterable[int], None] = None,
|
|
873
964
|
analytical_kl: bool = False,
|
|
874
|
-
bottomup_no_padding_mode: bool = False,
|
|
875
|
-
topdown_no_padding_mode: bool = False,
|
|
876
965
|
retain_spatial_dims: bool = False,
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
non_stochastic_version: bool = False,
|
|
880
|
-
input_image_shape: Union[None, Tuple[int, int]] = None,
|
|
966
|
+
vanilla_latent_hw: Union[Iterable[int], None] = None,
|
|
967
|
+
input_image_shape: Union[tuple[int, int], None] = None,
|
|
881
968
|
normalize_latent_factor: float = 1.0,
|
|
882
969
|
conv2d_bias: bool = True,
|
|
883
970
|
stochastic_use_naive_exponential: bool = False,
|
|
@@ -893,11 +980,13 @@ class TopDownLayer(nn.Module):
|
|
|
893
980
|
The number of TopDownDeterministicResBlock blocks
|
|
894
981
|
n_filters: int
|
|
895
982
|
The number of channels present through out the layers of this block.
|
|
983
|
+
conv_strides: tuple, optional
|
|
984
|
+
The strides used in the convolutions. Default is `(2, 2)`.
|
|
896
985
|
is_top_layer: bool, optional
|
|
897
986
|
Whether the current layer is at the top of the Decoder hierarchy. Default is `False`.
|
|
898
|
-
|
|
899
|
-
The number of
|
|
900
|
-
Default is `
|
|
987
|
+
upsampling_steps: int, optional
|
|
988
|
+
The number of upsampling steps that has to be done in this layer (typically 1).
|
|
989
|
+
Default is `None`.
|
|
901
990
|
nonlin: Callable, optional
|
|
902
991
|
The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
|
|
903
992
|
merge_type: Literal["linear", "residual", "residual_ungated"], optional
|
|
@@ -921,8 +1010,6 @@ class TopDownLayer(nn.Module):
|
|
|
921
1010
|
The kernel size used in the convolutions of the residual block.
|
|
922
1011
|
It can be either a single integer or a pair of integers defining the squared kernel.
|
|
923
1012
|
Default is `None`.
|
|
924
|
-
res_block_skip_padding: bool, optional
|
|
925
|
-
Whether to skip padding in convolutions in the Residual block. Default is `None`.
|
|
926
1013
|
groups: int, optional
|
|
927
1014
|
The number of groups to consider in the convolutions. Default is 1.
|
|
928
1015
|
gated: bool, optional
|
|
@@ -939,33 +1026,14 @@ class TopDownLayer(nn.Module):
|
|
|
939
1026
|
If True, KL divergence is calculated according to the analytical formula.
|
|
940
1027
|
Otherwise, an MC approximation using sampled latents is calculated.
|
|
941
1028
|
Default is `False`.
|
|
942
|
-
bottomup_no_padding_mode: bool, optional
|
|
943
|
-
Whether padding is used in the different layers of the bottom-up pass.
|
|
944
|
-
It is meaningful to know this in advance in order to assess whether before
|
|
945
|
-
merging `bu_values` and `p_params` tensors any alignment is needed.
|
|
946
|
-
Default is `False`.
|
|
947
|
-
topdown_no_padding_mode: bool, optional
|
|
948
|
-
Whether padding is used in the different layers of the top-down pass.
|
|
949
|
-
It is meaningful to know this in advance in order to assess whether before
|
|
950
|
-
merging `bu_values` and `p_params` tensors any alignment is needed.
|
|
951
|
-
The same information is also needed in handling the skip connections between
|
|
952
|
-
top-down layers. Default is `False`.
|
|
953
1029
|
retain_spatial_dims: bool, optional
|
|
954
1030
|
If `True`, the size of Encoder's latent space is kept to `input_image_shape` within the topdown layer.
|
|
955
1031
|
This implies that the oput spatial size equals the input spatial size.
|
|
956
1032
|
To achieve this, we centercrop the intermediate representation.
|
|
957
1033
|
Default is `False`.
|
|
958
|
-
restricted_kl: bool, optional
|
|
959
|
-
Whether to compute the restricted version of KL Divergence.
|
|
960
|
-
See `NormalStochasticBlock2d` module for more information about its computation.
|
|
961
|
-
Default is `False`.
|
|
962
1034
|
vanilla_latent_hw: Iterable[int], optional
|
|
963
1035
|
The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
|
|
964
1036
|
Default is `None`.
|
|
965
|
-
non_stochastic_version: bool, optional
|
|
966
|
-
Whether to replace the stochastic layer that samples a latent variable from the latent distribiution with
|
|
967
|
-
a non-stochastic layer that simply drwas a sample as the mode of the latent distribution.
|
|
968
|
-
Default is `False`.
|
|
969
1037
|
input_image_shape: Tuple[int, int], optionalut
|
|
970
1038
|
The shape of the input image tensor.
|
|
971
1039
|
When `retain_spatial_dims` is set to `True`, this is used to ensure that the shape of this layer
|
|
@@ -990,13 +1058,13 @@ class TopDownLayer(nn.Module):
|
|
|
990
1058
|
self.stochastic_skip = stochastic_skip
|
|
991
1059
|
self.learn_top_prior = learn_top_prior
|
|
992
1060
|
self.analytical_kl = analytical_kl
|
|
993
|
-
self.bottomup_no_padding_mode = bottomup_no_padding_mode
|
|
994
|
-
self.topdown_no_padding_mode = topdown_no_padding_mode
|
|
995
1061
|
self.retain_spatial_dims = retain_spatial_dims
|
|
996
|
-
self.
|
|
997
|
-
|
|
1062
|
+
self.input_image_shape = (
|
|
1063
|
+
input_image_shape if len(conv_strides) == 3 else input_image_shape[1:]
|
|
1064
|
+
)
|
|
1065
|
+
self.latent_shape = self.input_image_shape if self.retain_spatial_dims else None
|
|
998
1066
|
self.normalize_latent_factor = normalize_latent_factor
|
|
999
|
-
self._vanilla_latent_hw = vanilla_latent_hw
|
|
1067
|
+
self._vanilla_latent_hw = vanilla_latent_hw # TODO: check this, it is not used
|
|
1000
1068
|
|
|
1001
1069
|
# Define top layer prior parameters, possibly learnable
|
|
1002
1070
|
if is_top_layer:
|
|
@@ -1004,28 +1072,28 @@ class TopDownLayer(nn.Module):
|
|
|
1004
1072
|
torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
|
|
1005
1073
|
)
|
|
1006
1074
|
|
|
1007
|
-
#
|
|
1008
|
-
|
|
1075
|
+
# Upsampling steps left to do in this layer
|
|
1076
|
+
ups_left = upsampling_steps
|
|
1009
1077
|
|
|
1010
1078
|
# Define deterministic top-down block, which is a sequence of deterministic
|
|
1011
|
-
# residual blocks with (optional)
|
|
1079
|
+
# residual blocks with (optional) upsampling.
|
|
1012
1080
|
block_list = []
|
|
1013
1081
|
for _ in range(n_res_blocks):
|
|
1014
1082
|
do_resample = False
|
|
1015
|
-
if
|
|
1083
|
+
if ups_left > 0:
|
|
1016
1084
|
do_resample = True
|
|
1017
|
-
|
|
1085
|
+
ups_left -= 1
|
|
1018
1086
|
block_list.append(
|
|
1019
1087
|
TopDownDeterministicResBlock(
|
|
1020
1088
|
c_in=n_filters,
|
|
1021
1089
|
c_out=n_filters,
|
|
1090
|
+
conv_strides=conv_strides,
|
|
1022
1091
|
nonlin=nonlin,
|
|
1023
1092
|
upsample=do_resample,
|
|
1024
1093
|
batchnorm=batchnorm,
|
|
1025
1094
|
dropout=dropout,
|
|
1026
1095
|
res_block_type=res_block_type,
|
|
1027
1096
|
res_block_kernel=res_block_kernel,
|
|
1028
|
-
skip_padding=res_block_skip_padding,
|
|
1029
1097
|
gated=gated,
|
|
1030
1098
|
conv2d_bias=conv2d_bias,
|
|
1031
1099
|
groups=groups,
|
|
@@ -1033,32 +1101,24 @@ class TopDownLayer(nn.Module):
|
|
|
1033
1101
|
)
|
|
1034
1102
|
self.deterministic_block = nn.Sequential(*block_list)
|
|
1035
1103
|
|
|
1036
|
-
# Define stochastic block with
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
self.stochastic = NormalStochasticBlock2d(
|
|
1048
|
-
c_in=n_filters,
|
|
1049
|
-
c_vars=z_dim,
|
|
1050
|
-
c_out=n_filters,
|
|
1051
|
-
transform_p_params=(not is_top_layer),
|
|
1052
|
-
vanilla_latent_hw=vanilla_latent_hw,
|
|
1053
|
-
restricted_kl=restricted_kl,
|
|
1054
|
-
use_naive_exponential=stochastic_use_naive_exponential,
|
|
1055
|
-
)
|
|
1104
|
+
# Define stochastic block with convolutions
|
|
1105
|
+
|
|
1106
|
+
self.stochastic = NormalStochasticBlock(
|
|
1107
|
+
c_in=n_filters,
|
|
1108
|
+
c_vars=z_dim,
|
|
1109
|
+
c_out=n_filters,
|
|
1110
|
+
conv_dims=len(conv_strides),
|
|
1111
|
+
transform_p_params=(not is_top_layer),
|
|
1112
|
+
vanilla_latent_hw=vanilla_latent_hw,
|
|
1113
|
+
use_naive_exponential=stochastic_use_naive_exponential,
|
|
1114
|
+
)
|
|
1056
1115
|
|
|
1057
1116
|
if not is_top_layer:
|
|
1058
1117
|
# Merge layer: it combines bottom-up inference and top-down
|
|
1059
1118
|
# generative outcomes to give posterior parameters
|
|
1060
1119
|
self.merge = MergeLayer(
|
|
1061
1120
|
channels=n_filters,
|
|
1121
|
+
conv_strides=conv_strides,
|
|
1062
1122
|
merge_type=merge_type,
|
|
1063
1123
|
nonlin=nonlin,
|
|
1064
1124
|
batchnorm=batchnorm,
|
|
@@ -1072,6 +1132,7 @@ class TopDownLayer(nn.Module):
|
|
|
1072
1132
|
if stochastic_skip:
|
|
1073
1133
|
self.skip_connection_merger = SkipConnectionMerger(
|
|
1074
1134
|
channels=n_filters,
|
|
1135
|
+
conv_strides=conv_strides,
|
|
1075
1136
|
nonlin=nonlin,
|
|
1076
1137
|
batchnorm=batchnorm,
|
|
1077
1138
|
dropout=dropout,
|
|
@@ -1079,28 +1140,27 @@ class TopDownLayer(nn.Module):
|
|
|
1079
1140
|
merge_type=merge_type,
|
|
1080
1141
|
conv2d_bias=conv2d_bias,
|
|
1081
1142
|
res_block_kernel=res_block_kernel,
|
|
1082
|
-
res_block_skip_padding=res_block_skip_padding,
|
|
1083
1143
|
)
|
|
1084
1144
|
|
|
1085
|
-
# print(f'[{self.__class__.__name__}] normalize_latent_factor:{self.normalize_latent_factor}')
|
|
1086
|
-
|
|
1087
1145
|
def sample_from_q(
|
|
1088
1146
|
self,
|
|
1089
1147
|
input_: torch.Tensor,
|
|
1090
1148
|
bu_value: torch.Tensor,
|
|
1091
|
-
var_clip_max: float = None,
|
|
1149
|
+
var_clip_max: Optional[float] = None,
|
|
1092
1150
|
mask: torch.Tensor = None,
|
|
1093
1151
|
) -> torch.Tensor:
|
|
1094
1152
|
"""
|
|
1095
|
-
|
|
1153
|
+
Method computes the latent inference distribution q(z_i|z_{i+1}).
|
|
1154
|
+
|
|
1155
|
+
Used for sampling a latent tensor from it.
|
|
1096
1156
|
|
|
1097
1157
|
Parameters
|
|
1098
1158
|
----------
|
|
1099
1159
|
input_: torch.Tensor
|
|
1100
|
-
The input tensor to the layer, which is the output of the top-down layer
|
|
1160
|
+
The input tensor to the layer, which is the output of the top-down layer.
|
|
1101
1161
|
bu_value: torch.Tensor
|
|
1102
|
-
The tensor defining the parameters /mu_q and /sigma_q computed during the
|
|
1103
|
-
at the correspondent hierarchical layer.
|
|
1162
|
+
The tensor defining the parameters /mu_q and /sigma_q computed during the
|
|
1163
|
+
bottom-up deterministic pass at the correspondent hierarchical layer.
|
|
1104
1164
|
var_clip_max: float, optional
|
|
1105
1165
|
The maximum value reachable by the log-variance of the latent distribution.
|
|
1106
1166
|
Values exceeding this threshold are clipped. Default is `None`.
|
|
@@ -1127,9 +1187,11 @@ class TopDownLayer(nn.Module):
|
|
|
1127
1187
|
input_: torch.Tensor,
|
|
1128
1188
|
n_img_prior: int,
|
|
1129
1189
|
) -> torch.Tensor:
|
|
1130
|
-
"""
|
|
1131
|
-
|
|
1132
|
-
|
|
1190
|
+
"""Return the parameters of the prior distribution p(z_i|z_{i+1}).
|
|
1191
|
+
|
|
1192
|
+
The parameters depend on the hierarchical level of the layer:
|
|
1193
|
+
- if it is the topmost level, parameters are the ones of the prior.
|
|
1194
|
+
- else, the input from the layer above is the parameters itself.
|
|
1133
1195
|
|
|
1134
1196
|
Parameters
|
|
1135
1197
|
----------
|
|
@@ -1154,81 +1216,56 @@ class TopDownLayer(nn.Module):
|
|
|
1154
1216
|
|
|
1155
1217
|
return p_params
|
|
1156
1218
|
|
|
1157
|
-
def align_pparams_buvalue(
|
|
1158
|
-
self, p_params: torch.Tensor, bu_value: torch.Tensor
|
|
1159
|
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
1160
|
-
"""
|
|
1161
|
-
In case the padding is not used either (or both) in encoder and decoder, we could have a shape mismatch
|
|
1162
|
-
in the spatial dimensions (usually, dim=2 & dim=3).
|
|
1163
|
-
This method performs a centercrop to ensure that both remain aligned.
|
|
1164
|
-
|
|
1165
|
-
Parameters
|
|
1166
|
-
----------
|
|
1167
|
-
p_params: torch.Tensor
|
|
1168
|
-
The tensor defining the parameters /mu_p and /sigma_p for the latent distribution p(z_i|z_{i+1}).
|
|
1169
|
-
bu_value: torch.Tensor
|
|
1170
|
-
The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
|
|
1171
|
-
at the correspondent hierarchical layer.
|
|
1172
|
-
"""
|
|
1173
|
-
if bu_value.shape[-2:] != p_params.shape[-2:]:
|
|
1174
|
-
assert self.bottomup_no_padding_mode is True # TODO WTF ?
|
|
1175
|
-
if self.topdown_no_padding_mode is False:
|
|
1176
|
-
assert bu_value.shape[-1] > p_params.shape[-1]
|
|
1177
|
-
bu_value = F.center_crop(bu_value, p_params.shape[-2:])
|
|
1178
|
-
else:
|
|
1179
|
-
if bu_value.shape[-1] > p_params.shape[-1]:
|
|
1180
|
-
bu_value = F.center_crop(bu_value, p_params.shape[-2:])
|
|
1181
|
-
else:
|
|
1182
|
-
p_params = F.center_crop(p_params, bu_value.shape[-2:])
|
|
1183
|
-
return p_params, bu_value
|
|
1184
|
-
|
|
1185
1219
|
def forward(
|
|
1186
1220
|
self,
|
|
1187
|
-
input_: torch.Tensor = None,
|
|
1188
|
-
skip_connection_input: torch.Tensor = None,
|
|
1221
|
+
input_: Union[torch.Tensor, None] = None,
|
|
1222
|
+
skip_connection_input: Union[torch.Tensor, None] = None,
|
|
1189
1223
|
inference_mode: bool = False,
|
|
1190
|
-
bu_value: torch.Tensor = None,
|
|
1191
|
-
n_img_prior: int = None,
|
|
1192
|
-
forced_latent: torch.Tensor = None,
|
|
1193
|
-
use_mode: bool = False,
|
|
1224
|
+
bu_value: Union[torch.Tensor, None] = None,
|
|
1225
|
+
n_img_prior: Union[int, None] = None,
|
|
1226
|
+
forced_latent: Union[torch.Tensor, None] = None,
|
|
1194
1227
|
force_constant_output: bool = False,
|
|
1195
1228
|
mode_pred: bool = False,
|
|
1196
1229
|
use_uncond_mode: bool = False,
|
|
1197
|
-
var_clip_max: float = None,
|
|
1198
|
-
) ->
|
|
1199
|
-
"""
|
|
1230
|
+
var_clip_max: Union[float, None] = None,
|
|
1231
|
+
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
|
1232
|
+
"""Forward pass.
|
|
1233
|
+
|
|
1200
1234
|
Parameters
|
|
1201
1235
|
----------
|
|
1202
1236
|
input_: torch.Tensor, optional
|
|
1203
|
-
The input tensor to the layer, which is the output of the top-down layer
|
|
1237
|
+
The input tensor to the layer, which is the output of the top-down layer.
|
|
1204
1238
|
Default is `None`.
|
|
1205
1239
|
skip_connection_input: torch.Tensor, optional
|
|
1206
|
-
The tensor brought by the skip connection between the current and the
|
|
1240
|
+
The tensor brought by the skip connection between the current and the
|
|
1241
|
+
previous top-down layer.
|
|
1207
1242
|
Default is `None`.
|
|
1208
1243
|
inference_mode: bool, optional
|
|
1209
|
-
Whether the layer is in inference mode. See NOTE 2 in class description
|
|
1244
|
+
Whether the layer is in inference mode. See NOTE 2 in class description
|
|
1245
|
+
for more info.
|
|
1210
1246
|
Default is `False`.
|
|
1211
1247
|
bu_value: torch.Tensor, optional
|
|
1212
|
-
The tensor defining the parameters /mu_q and /sigma_q computed during the
|
|
1248
|
+
The tensor defining the parameters /mu_q and /sigma_q computed during the
|
|
1249
|
+
bottom-up deterministic pass
|
|
1213
1250
|
at the correspondent hierarchical layer. Default is `None`.
|
|
1214
1251
|
n_img_prior: int, optional
|
|
1215
|
-
The number of images to be generated from the unconditional prior
|
|
1252
|
+
The number of images to be generated from the unconditional prior
|
|
1253
|
+
distribution p(z_L).
|
|
1216
1254
|
Default is `None`.
|
|
1217
1255
|
forced_latent: torch.Tensor, optional
|
|
1218
|
-
A pre-defined latent tensor. If it is not `None`, than it is used as the
|
|
1256
|
+
A pre-defined latent tensor. If it is not `None`, than it is used as the
|
|
1257
|
+
actual latent tensor and,
|
|
1219
1258
|
hence, sampling does not happen. Default is `None`.
|
|
1220
|
-
use_mode: bool, optional
|
|
1221
|
-
Whether the latent tensor should be set as the latent distribution mode.
|
|
1222
|
-
In the case of Gaussian, the mode coincides with the mean of the distribution.
|
|
1223
|
-
Default is `False`.
|
|
1224
1259
|
force_constant_output: bool, optional
|
|
1225
|
-
Whether to copy the first sample (and rel. distrib parameters) over the
|
|
1260
|
+
Whether to copy the first sample (and rel. distrib parameters) over the
|
|
1261
|
+
whole batch.
|
|
1226
1262
|
This is used when doing experiment from the prior - q is not used.
|
|
1227
1263
|
Default is `False`.
|
|
1228
1264
|
mode_pred: bool, optional
|
|
1229
1265
|
Whether the model is in prediction mode. Default is `False`.
|
|
1230
1266
|
use_uncond_mode: bool, optional
|
|
1231
|
-
Whether to use the uncoditional distribution p(z) to sample latents in
|
|
1267
|
+
Whether to use the uncoditional distribution p(z) to sample latents in
|
|
1268
|
+
prediction mode.
|
|
1232
1269
|
var_clip_max: float
|
|
1233
1270
|
The maximum value reachable by the log-variance of the latent distribution.
|
|
1234
1271
|
Values exceeding this threshold are clipped.
|
|
@@ -1241,26 +1278,33 @@ class TopDownLayer(nn.Module):
|
|
|
1241
1278
|
p_params = self.get_p_params(input_, n_img_prior)
|
|
1242
1279
|
|
|
1243
1280
|
# Get the parameters for the latent distribution to sample from
|
|
1244
|
-
if inference_mode: # TODO What's this ?
|
|
1281
|
+
if inference_mode: # TODO What's this ? reuse Fede's code?
|
|
1245
1282
|
if self.is_top_layer:
|
|
1246
1283
|
q_params = bu_value
|
|
1247
1284
|
if mode_pred is False:
|
|
1248
|
-
p_params
|
|
1285
|
+
assert p_params.shape[2:] == bu_value.shape[2:], (
|
|
1286
|
+
"Spatial dimensions of p_params and bu_value should match. "
|
|
1287
|
+
f"Instead, we got p_params={p_params.shape[2:]} and "
|
|
1288
|
+
f"bu_value={bu_value.shape[2:]}."
|
|
1289
|
+
)
|
|
1249
1290
|
else:
|
|
1250
1291
|
if use_uncond_mode:
|
|
1251
1292
|
q_params = p_params
|
|
1252
1293
|
else:
|
|
1253
|
-
p_params
|
|
1294
|
+
assert p_params.shape[2:] == bu_value.shape[2:], (
|
|
1295
|
+
"Spatial dimensions of p_params and bu_value should match. "
|
|
1296
|
+
f"Instead, we got p_params={p_params.shape[2:]} and "
|
|
1297
|
+
f"bu_value={bu_value.shape[2:]}."
|
|
1298
|
+
)
|
|
1254
1299
|
q_params = self.merge(bu_value, p_params)
|
|
1255
|
-
#
|
|
1256
|
-
else:
|
|
1300
|
+
else: # generative mode, q is not used, we sample from p(z_i | z_{i+1})
|
|
1257
1301
|
q_params = None
|
|
1258
1302
|
|
|
1259
1303
|
# NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
|
|
1260
1304
|
# depending on the mode (hence, in practice, by checking whether q_params is None).
|
|
1261
1305
|
|
|
1262
|
-
# Normalization of latent space parameters
|
|
1263
|
-
#
|
|
1306
|
+
# Normalization of latent space parameters for stablity.
|
|
1307
|
+
# See Very deep VAEs generalize autoregressive models.
|
|
1264
1308
|
if self.normalize_latent_factor:
|
|
1265
1309
|
q_params = q_params / self.normalize_latent_factor
|
|
1266
1310
|
|
|
@@ -1269,52 +1313,44 @@ class TopDownLayer(nn.Module):
|
|
|
1269
1313
|
p_params=p_params,
|
|
1270
1314
|
q_params=q_params,
|
|
1271
1315
|
forced_latent=forced_latent,
|
|
1272
|
-
use_mode=use_mode,
|
|
1273
1316
|
force_constant_output=force_constant_output,
|
|
1274
1317
|
analytical_kl=self.analytical_kl,
|
|
1275
1318
|
mode_pred=mode_pred,
|
|
1276
1319
|
use_uncond_mode=use_uncond_mode,
|
|
1277
1320
|
var_clip_max=var_clip_max,
|
|
1278
1321
|
)
|
|
1279
|
-
|
|
1280
1322
|
# Merge skip connection from previous layer
|
|
1281
1323
|
if self.stochastic_skip and not self.is_top_layer:
|
|
1282
|
-
if self.topdown_no_padding_mode is True:
|
|
1283
|
-
# If no padding is done in the current top-down pass, there may be a shape mismatch between current tensor and skip connection input.
|
|
1284
|
-
# As an example, if the output of last TopDownLayer was of size 64*64, due to lack of padding in the current layer, the current tensor
|
|
1285
|
-
# might become different in shape, say 60*60.
|
|
1286
|
-
# In order to avoid shape mismatch, we do central crop of the skip connection input.
|
|
1287
|
-
skip_connection_input = F.center_crop(
|
|
1288
|
-
skip_connection_input, x.shape[-2:]
|
|
1289
|
-
)
|
|
1290
|
-
|
|
1291
1324
|
x = self.skip_connection_merger(x, skip_connection_input)
|
|
1292
|
-
|
|
1293
|
-
# Save activation before residual block as it can be the skip connection input in the next layer
|
|
1294
|
-
x_pre_residual = x
|
|
1295
|
-
|
|
1296
1325
|
if self.retain_spatial_dims:
|
|
1297
|
-
#
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
#
|
|
1301
|
-
#
|
|
1302
|
-
#
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
#
|
|
1326
|
+
# NOTE: we assume that one topdown layer will have exactly one upscaling layer.
|
|
1327
|
+
|
|
1328
|
+
# NOTE: in case, in the Bottom-Up layer, LC retains spatial dimensions,
|
|
1329
|
+
# we have the following (see `MergeLowRes`):
|
|
1330
|
+
# - the "primary-flow" tensor is padded to match the low-res patch size
|
|
1331
|
+
# (e.g., from 32x32 to 64x64)
|
|
1332
|
+
# - padded tensor is then merged with the low-res patch (concatenation
|
|
1333
|
+
# along dim=1 + convolution)
|
|
1334
|
+
# Therefore, we need to do the symmetric operation here, that is to
|
|
1335
|
+
# crop `x` for the same amount we padded it in the correspondent BU layer.
|
|
1336
|
+
|
|
1337
|
+
# NOTE: cropping is done to retain the shape of the input in the output.
|
|
1338
|
+
# Therefore we need it only in the case `x` is the same shape of the input,
|
|
1339
|
+
# because that's the only case in which we need to retain the shape.
|
|
1340
|
+
# Here, it must be strictly greater than half the input shape, which is
|
|
1341
|
+
# the case if and only if `x.shape == self.latent_shape`.
|
|
1342
|
+
rescale = (
|
|
1343
|
+
np.array((1, 2, 2)) if len(self.latent_shape) == 3 else np.array((2, 2))
|
|
1344
|
+
) # TODO better way?
|
|
1345
|
+
new_latent_shape = tuple(np.array(self.latent_shape) // rescale)
|
|
1309
1346
|
if x.shape[-1] > new_latent_shape[-1]:
|
|
1310
|
-
x =
|
|
1311
|
-
|
|
1312
|
-
#
|
|
1347
|
+
x = crop_img_tensor(x, new_latent_shape)
|
|
1348
|
+
# TODO: `retain_spatial_dims` is the same for all the TD layers.
|
|
1349
|
+
# How to handle the case in which we do not have LC for all layers?
|
|
1350
|
+
# The answer is in `self.latent_shape`, which is equal to `input_image_shape`
|
|
1351
|
+
# (e.g., (64, 64)) if `retain_spatial_dims` is `True`, else it is `None`.
|
|
1352
|
+
# Last top-down block (sequence of residual blocks w\ upsampling)
|
|
1313
1353
|
x = self.deterministic_block(x)
|
|
1314
|
-
|
|
1315
|
-
if self.topdown_no_padding_mode:
|
|
1316
|
-
x = F.center_crop(x, self.latent_shape)
|
|
1317
|
-
|
|
1318
1354
|
# Save some metrics that will be used in the loss computation
|
|
1319
1355
|
keys = [
|
|
1320
1356
|
"z",
|
|
@@ -1322,7 +1358,6 @@ class TopDownLayer(nn.Module):
|
|
|
1322
1358
|
"kl_samplewise_restricted",
|
|
1323
1359
|
"kl_spatial",
|
|
1324
1360
|
"kl_channelwise",
|
|
1325
|
-
# 'logprob_p',
|
|
1326
1361
|
"logprob_q",
|
|
1327
1362
|
"qvar_max",
|
|
1328
1363
|
]
|
|
@@ -1333,666 +1368,4 @@ class TopDownLayer(nn.Module):
|
|
|
1333
1368
|
q_mu, q_lv = data_stoch["q_params"]
|
|
1334
1369
|
data["q_mu"] = q_mu
|
|
1335
1370
|
data["q_lv"] = q_lv
|
|
1336
|
-
|
|
1337
|
-
return x, x_pre_residual, data
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
class NormalStochasticBlock2d(nn.Module):
|
|
1341
|
-
"""
|
|
1342
|
-
Stochastic block used in the Top-Down inference pass.
|
|
1343
|
-
|
|
1344
|
-
Algorithm:
|
|
1345
|
-
- map input parameters to q(z) and (optionally) p(z) via convolution
|
|
1346
|
-
- sample a latent tensor z ~ q(z)
|
|
1347
|
-
- feed z to convolution and return.
|
|
1348
|
-
|
|
1349
|
-
NOTE 1:
|
|
1350
|
-
If parameters for q are not given, sampling is done from p(z).
|
|
1351
|
-
|
|
1352
|
-
NOTE 2:
|
|
1353
|
-
The restricted KL divergence is obtained by first computing the element-wise KL divergence
|
|
1354
|
-
(i.e., the KL computed for each element of the latent tensors). Then, the restricted version
|
|
1355
|
-
is computed by summing over the channels and the spatial dimensions associated only to the
|
|
1356
|
-
portion of the latent tensor that is used for prediction.
|
|
1357
|
-
"""
|
|
1358
|
-
|
|
1359
|
-
def __init__(
|
|
1360
|
-
self,
|
|
1361
|
-
c_in: int,
|
|
1362
|
-
c_vars: int,
|
|
1363
|
-
c_out: int,
|
|
1364
|
-
kernel: int = 3,
|
|
1365
|
-
transform_p_params: bool = True,
|
|
1366
|
-
vanilla_latent_hw: int = None,
|
|
1367
|
-
restricted_kl: bool = False,
|
|
1368
|
-
use_naive_exponential: bool = False,
|
|
1369
|
-
):
|
|
1370
|
-
"""
|
|
1371
|
-
Parameters
|
|
1372
|
-
----------
|
|
1373
|
-
c_in: int
|
|
1374
|
-
The number of channels of the input tensor.
|
|
1375
|
-
c_vars: int
|
|
1376
|
-
The number of channels of the latent space tensor.
|
|
1377
|
-
c_out: int
|
|
1378
|
-
The output of the stochastic layer.
|
|
1379
|
-
Note that this is different from the sampled latent z.
|
|
1380
|
-
kernel: int, optional
|
|
1381
|
-
The size of the kernel used in convolutional layers.
|
|
1382
|
-
Default is 3.
|
|
1383
|
-
transform_p_params: bool, optional
|
|
1384
|
-
Whether a transformation should be applied to the `p_params` tensor.
|
|
1385
|
-
The transformation consists in a 2D convolution ()`conv_in_p()`) that
|
|
1386
|
-
maps the input to a larger number of channels.
|
|
1387
|
-
Default is `True`.
|
|
1388
|
-
vanilla_latent_hw: int, optional
|
|
1389
|
-
The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
|
|
1390
|
-
Default is `None`.
|
|
1391
|
-
restricted_kl: bool, optional
|
|
1392
|
-
Whether to compute the restricted version of KL Divergence.
|
|
1393
|
-
See NOTE 2 for more information about its computation.
|
|
1394
|
-
Default is `False`.
|
|
1395
|
-
use_naive_exponential: bool, optional
|
|
1396
|
-
If `False`, exponentials are computed according to the alternative definition
|
|
1397
|
-
provided by `StableExponential` class. This should improve numerical stability
|
|
1398
|
-
in the training process. Default is `False`.
|
|
1399
|
-
"""
|
|
1400
|
-
super().__init__()
|
|
1401
|
-
assert kernel % 2 == 1
|
|
1402
|
-
pad = kernel // 2
|
|
1403
|
-
self.transform_p_params = transform_p_params
|
|
1404
|
-
self.c_in = c_in
|
|
1405
|
-
self.c_out = c_out
|
|
1406
|
-
self.c_vars = c_vars
|
|
1407
|
-
self._use_naive_exponential = use_naive_exponential
|
|
1408
|
-
self._vanilla_latent_hw = vanilla_latent_hw
|
|
1409
|
-
self._restricted_kl = restricted_kl
|
|
1410
|
-
|
|
1411
|
-
if transform_p_params:
|
|
1412
|
-
self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
|
|
1413
|
-
self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
|
|
1414
|
-
self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)
|
|
1415
|
-
|
|
1416
|
-
# def forward_swapped(self, p_params, q_mu, q_lv):
|
|
1417
|
-
#
|
|
1418
|
-
# if self.transform_p_params:
|
|
1419
|
-
# p_params = self.conv_in_p(p_params)
|
|
1420
|
-
# else:
|
|
1421
|
-
# assert p_params.size(1) == 2 * self.c_vars
|
|
1422
|
-
#
|
|
1423
|
-
# # Define p(z)
|
|
1424
|
-
# p_mu, p_lv = p_params.chunk(2, dim=1)
|
|
1425
|
-
# p = Normal(p_mu, (p_lv / 2).exp())
|
|
1426
|
-
#
|
|
1427
|
-
# # Define q(z)
|
|
1428
|
-
# q = Normal(q_mu, (q_lv / 2).exp())
|
|
1429
|
-
# # Sample from q(z)
|
|
1430
|
-
# sampling_distrib = q
|
|
1431
|
-
#
|
|
1432
|
-
# # Generate latent variable (typically by sampling)
|
|
1433
|
-
# z = sampling_distrib.rsample()
|
|
1434
|
-
#
|
|
1435
|
-
# # Output of stochastic layer
|
|
1436
|
-
# out = self.conv_out(z)
|
|
1437
|
-
#
|
|
1438
|
-
# data = {
|
|
1439
|
-
# 'z': z, # sampled variable at this layer (batch, ch, h, w)
|
|
1440
|
-
# 'p_params': p_params, # (b, ch, h, w) where b is 1 or batch size
|
|
1441
|
-
# }
|
|
1442
|
-
# return out, data
|
|
1443
|
-
|
|
1444
|
-
def get_z(
|
|
1445
|
-
self,
|
|
1446
|
-
sampling_distrib: torch.distributions.normal.Normal,
|
|
1447
|
-
forced_latent: torch.Tensor,
|
|
1448
|
-
use_mode: bool,
|
|
1449
|
-
mode_pred: bool,
|
|
1450
|
-
use_uncond_mode: bool,
|
|
1451
|
-
) -> torch.Tensor:
|
|
1452
|
-
"""
|
|
1453
|
-
This method enables to sample a latent tensor given the distribution to sample from.
|
|
1454
|
-
|
|
1455
|
-
Latent variable can be obtained is several ways:
|
|
1456
|
-
- Sampled from the (Gaussian) latent distribution.
|
|
1457
|
-
- Taken as a pre-defined forced latent.
|
|
1458
|
-
- Taken as the mode (mean) of the latent distribution.
|
|
1459
|
-
- In prediction mode (`mode_pred==True`), can be either sample or taken as the distribution mode.
|
|
1460
|
-
|
|
1461
|
-
Parameters
|
|
1462
|
-
----------
|
|
1463
|
-
sampling_distrib: torch.distributions.normal.Normal
|
|
1464
|
-
The Gaussian distribution from which latent tensor is sampled.
|
|
1465
|
-
forced_latent: torch.Tensor
|
|
1466
|
-
A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and,
|
|
1467
|
-
hence, sampling does not happen.
|
|
1468
|
-
use_mode: bool
|
|
1469
|
-
Whether the latent tensor should be set as the latent distribution mode.
|
|
1470
|
-
In the case of Gaussian, the mode coincides with the mean of the distribution.
|
|
1471
|
-
mode_pred: bool
|
|
1472
|
-
Whether the model is prediction mode.
|
|
1473
|
-
use_uncond_mode: bool
|
|
1474
|
-
Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
|
|
1475
|
-
"""
|
|
1476
|
-
if forced_latent is None:
|
|
1477
|
-
if use_mode:
|
|
1478
|
-
z = sampling_distrib.mean
|
|
1479
|
-
else:
|
|
1480
|
-
if mode_pred:
|
|
1481
|
-
if use_uncond_mode:
|
|
1482
|
-
z = sampling_distrib.mean
|
|
1483
|
-
else:
|
|
1484
|
-
z = sampling_distrib.rsample()
|
|
1485
|
-
else:
|
|
1486
|
-
z = sampling_distrib.rsample()
|
|
1487
|
-
else:
|
|
1488
|
-
z = forced_latent
|
|
1489
|
-
return z
|
|
1490
|
-
|
|
1491
|
-
def sample_from_q(
|
|
1492
|
-
self, q_params: torch.Tensor, var_clip_max: float
|
|
1493
|
-
) -> torch.Tensor:
|
|
1494
|
-
"""
|
|
1495
|
-
Given an input parameter tensor defining q(z),
|
|
1496
|
-
it processes it by calling `process_q_params()` method and
|
|
1497
|
-
sample a latent tensor from the resulting distribution.
|
|
1498
|
-
|
|
1499
|
-
Parameters
|
|
1500
|
-
----------
|
|
1501
|
-
q_params: torch.Tensor
|
|
1502
|
-
The input tensor to be processed.
|
|
1503
|
-
var_clip_max: float
|
|
1504
|
-
The maximum value reachable by the log-variance of the latent distribution.
|
|
1505
|
-
Values exceeding this threshold are clipped.
|
|
1506
|
-
"""
|
|
1507
|
-
_, _, q = self.process_q_params(q_params, var_clip_max)
|
|
1508
|
-
return q.rsample()
|
|
1509
|
-
|
|
1510
|
-
def compute_kl_metrics(
|
|
1511
|
-
self,
|
|
1512
|
-
p: torch.distributions.normal.Normal,
|
|
1513
|
-
p_params: torch.Tensor,
|
|
1514
|
-
q: torch.distributions.normal.Normal,
|
|
1515
|
-
q_params: torch.Tensor,
|
|
1516
|
-
mode_pred: bool,
|
|
1517
|
-
analytical_kl: bool,
|
|
1518
|
-
z: torch.Tensor,
|
|
1519
|
-
) -> Dict[str, torch.Tensor]:
|
|
1520
|
-
"""
|
|
1521
|
-
Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric.
|
|
1522
|
-
Specifically, the different versions of the KL loss terms are:
|
|
1523
|
-
- `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
|
|
1524
|
-
- `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
|
|
1525
|
-
- `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
|
|
1526
|
-
used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
|
|
1527
|
-
- `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
|
|
1528
|
-
- `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
|
|
1529
|
-
|
|
1530
|
-
Parameters
|
|
1531
|
-
----------
|
|
1532
|
-
p: torch.distributions.normal.Normal
|
|
1533
|
-
The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
|
|
1534
|
-
p_params: torch.Tensor
|
|
1535
|
-
The parameters of the prior generative distribution.
|
|
1536
|
-
q: torch.distributions.normal.Normal
|
|
1537
|
-
The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
|
|
1538
|
-
q_params: torch.Tensor
|
|
1539
|
-
The parameters of the inference distribution.
|
|
1540
|
-
mode_pred: bool
|
|
1541
|
-
Whether the model is in prediction mode.
|
|
1542
|
-
analytical_kl: bool
|
|
1543
|
-
Whether to compute the KL divergence analytically or using Monte Carlo estimation.
|
|
1544
|
-
z: torch.Tensor
|
|
1545
|
-
The sampled latent tensor.
|
|
1546
|
-
"""
|
|
1547
|
-
kl_samplewise_restricted = None
|
|
1548
|
-
|
|
1549
|
-
if mode_pred is False: # if not in prediction mode
|
|
1550
|
-
# KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]
|
|
1551
|
-
if analytical_kl:
|
|
1552
|
-
kl_elementwise = kl_divergence(q, p)
|
|
1553
|
-
else:
|
|
1554
|
-
kl_elementwise = kl_normal_mc(z, p_params, q_params)
|
|
1555
|
-
|
|
1556
|
-
# KL term only associated to the portion of the latent tensor that is used for prediction and
|
|
1557
|
-
# summed over channel and spatial dimensions. [Shape: (batch, )]
|
|
1558
|
-
# NOTE: vanilla_latent_hw is the shape of the latent tensor used for prediction, hence
|
|
1559
|
-
# the restriction has shape [Shape: (batch, ch, vanilla_latent_hw[0], vanilla_latent_hw[1])]
|
|
1560
|
-
if self._restricted_kl:
|
|
1561
|
-
pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2
|
|
1562
|
-
assert pad > 0, "Disable restricted kl since there is no restriction."
|
|
1563
|
-
tmp = kl_elementwise[..., pad:-pad, pad:-pad]
|
|
1564
|
-
kl_samplewise_restricted = tmp.sum((1, 2, 3))
|
|
1565
|
-
|
|
1566
|
-
# KL term associated to each sample in the batch [Shape: (batch, )]
|
|
1567
|
-
kl_samplewise = kl_elementwise.sum((1, 2, 3))
|
|
1568
|
-
|
|
1569
|
-
# KL term associated to each sample and each channel [Shape: (batch, ch, )]
|
|
1570
|
-
kl_channelwise = kl_elementwise.sum((2, 3))
|
|
1571
|
-
|
|
1572
|
-
# KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
|
|
1573
|
-
kl_spatial = kl_elementwise.sum(1)
|
|
1574
|
-
else: # if predicting, no need to compute KL
|
|
1575
|
-
kl_elementwise = kl_samplewise = kl_spatial = kl_channelwise = None
|
|
1576
|
-
|
|
1577
|
-
kl_dict = {
|
|
1578
|
-
"kl_elementwise": kl_elementwise, # (batch, ch, h, w)
|
|
1579
|
-
"kl_samplewise": kl_samplewise, # (batch, )
|
|
1580
|
-
"kl_samplewise_restricted": kl_samplewise_restricted, # (batch, )
|
|
1581
|
-
"kl_channelwise": kl_channelwise, # (batch, ch)
|
|
1582
|
-
"kl_spatial": kl_spatial, # (batch, h, w)
|
|
1583
|
-
}
|
|
1584
|
-
return kl_dict
|
|
1585
|
-
|
|
1586
|
-
def process_p_params(
|
|
1587
|
-
self, p_params: torch.Tensor, var_clip_max: float
|
|
1588
|
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
|
|
1589
|
-
"""
|
|
1590
|
-
Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)).
|
|
1591
|
-
|
|
1592
|
-
Processing consists in:
|
|
1593
|
-
- (optionally) 2D convolution on the input tensor to increase number of channels.
|
|
1594
|
-
- split the resulting tensor into two chunks, the mean and the log-variance.
|
|
1595
|
-
- (optionally) clip the log-variance to an upper threshold.
|
|
1596
|
-
- define the normal distribution p(z) given the parameter tensors above.
|
|
1597
|
-
|
|
1598
|
-
Parameters
|
|
1599
|
-
----------
|
|
1600
|
-
p_params: torch.Tensor
|
|
1601
|
-
The input tensor to be processed.
|
|
1602
|
-
var_clip_max: float
|
|
1603
|
-
The maximum value reachable by the log-variance of the latent distribution.
|
|
1604
|
-
Values exceeding this threshold are clipped.
|
|
1605
|
-
"""
|
|
1606
|
-
if self.transform_p_params:
|
|
1607
|
-
p_params = self.conv_in_p(p_params)
|
|
1608
|
-
else:
|
|
1609
|
-
assert p_params.size(1) == 2 * self.c_vars
|
|
1610
|
-
|
|
1611
|
-
# Define p(z)
|
|
1612
|
-
p_mu, p_lv = p_params.chunk(2, dim=1)
|
|
1613
|
-
if var_clip_max is not None:
|
|
1614
|
-
p_lv = torch.clip(p_lv, max=var_clip_max)
|
|
1615
|
-
|
|
1616
|
-
p_mu = StableMean(p_mu)
|
|
1617
|
-
p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential)
|
|
1618
|
-
p = Normal(p_mu.get(), p_lv.get_std())
|
|
1619
|
-
return p_mu, p_lv, p
|
|
1620
|
-
|
|
1621
|
-
def process_q_params(
|
|
1622
|
-
self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False
|
|
1623
|
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
|
|
1624
|
-
"""
|
|
1625
|
-
Process the input parameters to get the inference distribution q(z_i|z_{i+1}) (or q(z|x)).
|
|
1626
|
-
|
|
1627
|
-
Processing consists in:
|
|
1628
|
-
- 2D convolution on the input tensor to increase number of channels.
|
|
1629
|
-
- split the resulting tensor into two chunks, the mean and the log-variance.
|
|
1630
|
-
- (optionally) clip the log-variance to an upper threshold.
|
|
1631
|
-
- (optionally) crop the resulting tensors to ensure that the last spatial dimension is even.
|
|
1632
|
-
- define the normal distribution q(z) given the parameter tensors above.
|
|
1633
|
-
|
|
1634
|
-
Parameters
|
|
1635
|
-
----------
|
|
1636
|
-
p_params: torch.Tensor
|
|
1637
|
-
The input tensor to be processed.
|
|
1638
|
-
var_clip_max: float
|
|
1639
|
-
The maximum value reachable by the log-variance of the latent distribution.
|
|
1640
|
-
Values exceeding this threshold are clipped.
|
|
1641
|
-
"""
|
|
1642
|
-
q_params = self.conv_in_q(q_params)
|
|
1643
|
-
|
|
1644
|
-
q_mu, q_lv = q_params.chunk(2, dim=1)
|
|
1645
|
-
if var_clip_max is not None:
|
|
1646
|
-
q_lv = torch.clip(q_lv, max=var_clip_max)
|
|
1647
|
-
|
|
1648
|
-
if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
|
|
1649
|
-
q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
|
|
1650
|
-
q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1)
|
|
1651
|
-
# clip_start = np.random.rand() > 0.5
|
|
1652
|
-
# q_mu = q_mu[:, :, 1:, 1:] if clip_start else q_mu[:, :, :-1, :-1]
|
|
1653
|
-
# q_lv = q_lv[:, :, 1:, 1:] if clip_start else q_lv[:, :, :-1, :-1]
|
|
1654
|
-
|
|
1655
|
-
q_mu = StableMean(q_mu)
|
|
1656
|
-
q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential)
|
|
1657
|
-
q = Normal(q_mu.get(), q_lv.get_std())
|
|
1658
|
-
return q_mu, q_lv, q
|
|
1659
|
-
|
|
1660
|
-
def forward(
|
|
1661
|
-
self,
|
|
1662
|
-
p_params: torch.Tensor,
|
|
1663
|
-
q_params: torch.Tensor = None,
|
|
1664
|
-
forced_latent: torch.Tensor = None,
|
|
1665
|
-
use_mode: bool = False,
|
|
1666
|
-
force_constant_output: bool = False,
|
|
1667
|
-
analytical_kl: bool = False,
|
|
1668
|
-
mode_pred: bool = False,
|
|
1669
|
-
use_uncond_mode: bool = False,
|
|
1670
|
-
var_clip_max: float = None,
|
|
1671
|
-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
|
1672
|
-
"""
|
|
1673
|
-
Parameters
|
|
1674
|
-
----------
|
|
1675
|
-
p_params: torch.Tensor
|
|
1676
|
-
The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
|
|
1677
|
-
q_params: torch.Tensor, optional
|
|
1678
|
-
The tensor resulting from merging the bu_value tensor at the same hierarchical level
|
|
1679
|
-
from the bottom-up pass and the `p_params` tensor. Default is `None`.
|
|
1680
|
-
forced_latent: torch.Tensor, optional
|
|
1681
|
-
A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent
|
|
1682
|
-
tensor and, hence, sampling does not happen. Default is `None`.
|
|
1683
|
-
use_mode: bool, optional
|
|
1684
|
-
Whether the latent tensor should be set as the latent distribution mode.
|
|
1685
|
-
In the case of Gaussian, the mode coincides with the mean of the distribution.
|
|
1686
|
-
Default is `False`.
|
|
1687
|
-
force_constant_output: bool, optional
|
|
1688
|
-
Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
|
|
1689
|
-
This is used when doing experiment from the prior - q is not used.
|
|
1690
|
-
Default is `False`.
|
|
1691
|
-
analytical_kl: bool, optional
|
|
1692
|
-
Whether to compute the KL divergence analytically or using Monte Carlo estimation.
|
|
1693
|
-
Default is `False`.
|
|
1694
|
-
mode_pred: bool, optional
|
|
1695
|
-
Whether the model is in prediction mode. Default is `False`.
|
|
1696
|
-
use_uncond_mode: bool, optional
|
|
1697
|
-
Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
|
|
1698
|
-
Default is `False`.
|
|
1699
|
-
var_clip_max: float, optional
|
|
1700
|
-
The maximum value reachable by the log-variance of the latent distribution.
|
|
1701
|
-
Values exceeding this threshold are clipped. Default is `None`.
|
|
1702
|
-
"""
|
|
1703
|
-
debug_qvar_max = 0
|
|
1704
|
-
|
|
1705
|
-
# Check sampling options consistency
|
|
1706
|
-
assert (forced_latent is None) or (not use_mode)
|
|
1707
|
-
|
|
1708
|
-
# Get generative distribution p(z_i|z_{i+1})
|
|
1709
|
-
p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max)
|
|
1710
|
-
p_params = (p_mu, p_lv)
|
|
1711
|
-
|
|
1712
|
-
if q_params is not None:
|
|
1713
|
-
# Get inference distribution q(z_i|z_{i+1})
|
|
1714
|
-
# NOTE: At inference time, don't centercrop the q_params even if they are odd in size.
|
|
1715
|
-
q_mu, q_lv, q = self.process_q_params(
|
|
1716
|
-
q_params, var_clip_max, allow_oddsizes=mode_pred is True
|
|
1717
|
-
)
|
|
1718
|
-
q_params = (q_mu, q_lv)
|
|
1719
|
-
sampling_distrib = q
|
|
1720
|
-
debug_qvar_max = torch.max(q_lv.get())
|
|
1721
|
-
|
|
1722
|
-
# Centercrop p_params so that their size matches the one of q_params
|
|
1723
|
-
q_size = q_mu.get().shape[-1]
|
|
1724
|
-
if p_mu.get().shape[-1] != q_size and mode_pred is False:
|
|
1725
|
-
p_mu.centercrop_to_size(q_size)
|
|
1726
|
-
p_lv.centercrop_to_size(q_size)
|
|
1727
|
-
else:
|
|
1728
|
-
sampling_distrib = p
|
|
1729
|
-
|
|
1730
|
-
# Sample latent variable
|
|
1731
|
-
z = self.get_z(
|
|
1732
|
-
sampling_distrib, forced_latent, use_mode, mode_pred, use_uncond_mode
|
|
1733
|
-
)
|
|
1734
|
-
|
|
1735
|
-
# Copy one sample (and distrib parameters) over the whole batch.
|
|
1736
|
-
# This is used when doing experiment from the prior - q is not used.
|
|
1737
|
-
if force_constant_output:
|
|
1738
|
-
z = z[0:1].expand_as(z).clone()
|
|
1739
|
-
p_params = (
|
|
1740
|
-
p_params[0][0:1].expand_as(p_params[0]).clone(),
|
|
1741
|
-
p_params[1][0:1].expand_as(p_params[1]).clone(),
|
|
1742
|
-
)
|
|
1743
|
-
|
|
1744
|
-
# Pass the sampled latent througn the output convolutional layer of stochastic block
|
|
1745
|
-
out = self.conv_out(z)
|
|
1746
|
-
|
|
1747
|
-
# Compute log p(z)# NOTE: disabling its computation.
|
|
1748
|
-
# if mode_pred is False:
|
|
1749
|
-
# logprob_p = p.log_prob(z).sum((1, 2, 3))
|
|
1750
|
-
# else:
|
|
1751
|
-
# logprob_p = None
|
|
1752
|
-
|
|
1753
|
-
if q_params is not None:
|
|
1754
|
-
# Compute log q(z)
|
|
1755
|
-
logprob_q = q.log_prob(z).sum((1, 2, 3))
|
|
1756
|
-
# Compute KL divergence metrics
|
|
1757
|
-
kl_dict = self.compute_kl_metrics(
|
|
1758
|
-
p, p_params, q, q_params, mode_pred, analytical_kl, z
|
|
1759
|
-
)
|
|
1760
|
-
else:
|
|
1761
|
-
kl_dict = {}
|
|
1762
|
-
logprob_q = None
|
|
1763
|
-
|
|
1764
|
-
# Store meaningful quantities to use them in following layers
|
|
1765
|
-
data = kl_dict
|
|
1766
|
-
data["z"] = z # sampled variable at this layer (batch, ch, h, w)
|
|
1767
|
-
data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size
|
|
1768
|
-
data["q_params"] = q_params # (batch, ch, h, w)
|
|
1769
|
-
# data['logprob_p'] = logprob_p # (batch, )
|
|
1770
|
-
data["logprob_q"] = logprob_q # (batch, )
|
|
1771
|
-
data["qvar_max"] = debug_qvar_max
|
|
1772
|
-
|
|
1773
|
-
return out, data
|
|
1774
|
-
|
|
1775
|
-
|
|
1776
|
-
class NonStochasticBlock2d(nn.Module):
|
|
1777
|
-
"""
|
|
1778
|
-
Non-stochastic version of the NormalStochasticBlock2d.
|
|
1779
|
-
"""
|
|
1780
|
-
|
|
1781
|
-
def __init__(
|
|
1782
|
-
self,
|
|
1783
|
-
c_vars: int,
|
|
1784
|
-
c_in: int,
|
|
1785
|
-
c_out: int,
|
|
1786
|
-
kernel: int = 3,
|
|
1787
|
-
groups: int = 1,
|
|
1788
|
-
conv2d_bias: bool = True,
|
|
1789
|
-
transform_p_params: bool = True,
|
|
1790
|
-
):
|
|
1791
|
-
"""
|
|
1792
|
-
Constructor.
|
|
1793
|
-
|
|
1794
|
-
Parameters
|
|
1795
|
-
----------
|
|
1796
|
-
c_vars: int
|
|
1797
|
-
The number of channels of the latent space tensor.
|
|
1798
|
-
c_in: int
|
|
1799
|
-
The number of channels of the input tensor.
|
|
1800
|
-
c_out: int
|
|
1801
|
-
The output of the stochastic layer.
|
|
1802
|
-
Note that this is different from the sampled latent z.
|
|
1803
|
-
kernel: int, optional
|
|
1804
|
-
The size of the kernel used in convolutional layers.
|
|
1805
|
-
Default is 3.
|
|
1806
|
-
groups: int, optional
|
|
1807
|
-
The number of groups to consider in the convolutions of this layer.
|
|
1808
|
-
Default is 1.
|
|
1809
|
-
conv2d_bias: bool, optional
|
|
1810
|
-
Whether to use bias term is the convolutional blocks of this layer.
|
|
1811
|
-
Default is `True`.
|
|
1812
|
-
transform_p_params: bool, optional
|
|
1813
|
-
Whether a transformation should be applied to the `p_params` tensor.
|
|
1814
|
-
The transformation consists in a 2D convolution ()`conv_in_p()`) that
|
|
1815
|
-
maps the input to a larger number of channels.
|
|
1816
|
-
Default is `True`.
|
|
1817
|
-
"""
|
|
1818
|
-
super().__init__()
|
|
1819
|
-
assert kernel % 2 == 1
|
|
1820
|
-
pad = kernel // 2
|
|
1821
|
-
self.transform_p_params = transform_p_params
|
|
1822
|
-
self.c_in = c_in
|
|
1823
|
-
self.c_out = c_out
|
|
1824
|
-
self.c_vars = c_vars
|
|
1825
|
-
|
|
1826
|
-
if transform_p_params:
|
|
1827
|
-
self.conv_in_p = nn.Conv2d(
|
|
1828
|
-
c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
|
|
1829
|
-
)
|
|
1830
|
-
self.conv_in_q = nn.Conv2d(
|
|
1831
|
-
c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
|
|
1832
|
-
)
|
|
1833
|
-
self.conv_out = nn.Conv2d(
|
|
1834
|
-
c_vars, c_out, kernel, padding=pad, bias=conv2d_bias, groups=groups
|
|
1835
|
-
)
|
|
1836
|
-
|
|
1837
|
-
def compute_kl_metrics(
|
|
1838
|
-
self,
|
|
1839
|
-
p: torch.distributions.normal.Normal,
|
|
1840
|
-
p_params: torch.Tensor,
|
|
1841
|
-
q: torch.distributions.normal.Normal,
|
|
1842
|
-
q_params: torch.Tensor,
|
|
1843
|
-
mode_pred: bool,
|
|
1844
|
-
analytical_kl: bool,
|
|
1845
|
-
z: torch.Tensor,
|
|
1846
|
-
) -> Dict[str, None]:
|
|
1847
|
-
"""
|
|
1848
|
-
Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric.
|
|
1849
|
-
Specifically, the different versions of the KL loss terms are:
|
|
1850
|
-
- `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
|
|
1851
|
-
- `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
|
|
1852
|
-
- `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
|
|
1853
|
-
used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
|
|
1854
|
-
- `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
|
|
1855
|
-
- `kl_spatial`: # KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
|
|
1856
|
-
|
|
1857
|
-
NOTE: in this class all the KL metrics are set to `None`.
|
|
1858
|
-
|
|
1859
|
-
Parameters
|
|
1860
|
-
----------
|
|
1861
|
-
p: torch.distributions.normal.Normal
|
|
1862
|
-
The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
|
|
1863
|
-
p_params: torch.Tensor
|
|
1864
|
-
The parameters of the prior generative distribution.
|
|
1865
|
-
q: torch.distributions.normal.Normal
|
|
1866
|
-
The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
|
|
1867
|
-
q_params: torch.Tensor
|
|
1868
|
-
The parameters of the inference distribution.
|
|
1869
|
-
mode_pred: bool
|
|
1870
|
-
Whether the model is in prediction mode.
|
|
1871
|
-
analytical_kl: bool
|
|
1872
|
-
Whether to compute the KL divergence analytically or using Monte Carlo estimation.
|
|
1873
|
-
z: torch.Tensor
|
|
1874
|
-
The sampled latent tensor.
|
|
1875
|
-
"""
|
|
1876
|
-
kl_dict = {
|
|
1877
|
-
"kl_elementwise": None, # (batch, ch, h, w)
|
|
1878
|
-
"kl_samplewise": None, # (batch, )
|
|
1879
|
-
"kl_spatial": None, # (batch, h, w)
|
|
1880
|
-
"kl_channelwise": None, # (batch, ch)
|
|
1881
|
-
}
|
|
1882
|
-
return kl_dict
|
|
1883
|
-
|
|
1884
|
-
def process_p_params(self, p_params, var_clip_max):
|
|
1885
|
-
if self.transform_p_params:
|
|
1886
|
-
p_params = self.conv_in_p(p_params)
|
|
1887
|
-
else:
|
|
1888
|
-
|
|
1889
|
-
assert (
|
|
1890
|
-
p_params.size(1) == 2 * self.c_vars
|
|
1891
|
-
), f"{p_params.shape} {self.c_vars}"
|
|
1892
|
-
|
|
1893
|
-
# Define p(z)
|
|
1894
|
-
p_mu, p_lv = p_params.chunk(2, dim=1)
|
|
1895
|
-
return p_mu, None
|
|
1896
|
-
|
|
1897
|
-
def process_q_params(self, q_params, var_clip_max, allow_oddsizes=False):
|
|
1898
|
-
# Define q(z)
|
|
1899
|
-
q_params = self.conv_in_q(q_params)
|
|
1900
|
-
q_mu, q_lv = q_params.chunk(2, dim=1)
|
|
1901
|
-
|
|
1902
|
-
if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
|
|
1903
|
-
q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
|
|
1904
|
-
|
|
1905
|
-
return q_mu, None
|
|
1906
|
-
|
|
1907
|
-
def forward(
|
|
1908
|
-
self,
|
|
1909
|
-
p_params: torch.Tensor,
|
|
1910
|
-
q_params: torch.Tensor = None,
|
|
1911
|
-
forced_latent: Union[None, torch.Tensor] = None,
|
|
1912
|
-
use_mode: bool = False,
|
|
1913
|
-
force_constant_output: bool = False,
|
|
1914
|
-
analytical_kl: bool = False,
|
|
1915
|
-
mode_pred: bool = False,
|
|
1916
|
-
use_uncond_mode: bool = False,
|
|
1917
|
-
var_clip_max: float = None,
|
|
1918
|
-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
|
1919
|
-
"""
|
|
1920
|
-
Parameters
|
|
1921
|
-
----------
|
|
1922
|
-
p_params: torch.Tensor
|
|
1923
|
-
The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
|
|
1924
|
-
q_params: torch.Tensor, optional
|
|
1925
|
-
The tensor resulting from merging the bu_value tensor at the same hierarchical level
|
|
1926
|
-
from the bottom-up pass and the `p_params` tensor. Default is `None`.
|
|
1927
|
-
forced_latent: torch.Tensor, optional
|
|
1928
|
-
A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent
|
|
1929
|
-
tensor and, hence, sampling does not happen. Default is `None`.
|
|
1930
|
-
use_mode: bool, optional
|
|
1931
|
-
Whether the latent tensor should be set as the latent distribution mode.
|
|
1932
|
-
In the case of Gaussian, the mode coincides with the mean of the distribution.
|
|
1933
|
-
Default is `False`.
|
|
1934
|
-
force_constant_output: bool, optional
|
|
1935
|
-
Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
|
|
1936
|
-
This is used when doing experiment from the prior - q is not used.
|
|
1937
|
-
Default is `False`.
|
|
1938
|
-
analytical_kl: bool, optional
|
|
1939
|
-
Whether to compute the KL divergence analytically or using Monte Carlo estimation.
|
|
1940
|
-
Default is `False`.
|
|
1941
|
-
mode_pred: bool, optional
|
|
1942
|
-
Whether the model is in prediction mode. Default is `False`.
|
|
1943
|
-
use_uncond_mode: bool, optional
|
|
1944
|
-
Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
|
|
1945
|
-
Default is `False`.
|
|
1946
|
-
var_clip_max: float, optional
|
|
1947
|
-
The maximum value reachable by the log-variance of the latent distribution.
|
|
1948
|
-
Values exceeding this threshold are clipped. Default is `None`.
|
|
1949
|
-
"""
|
|
1950
|
-
debug_qvar_max = 0
|
|
1951
|
-
assert (forced_latent is None) or (not use_mode)
|
|
1952
|
-
|
|
1953
|
-
p_mu, _ = self.process_p_params(p_params, var_clip_max)
|
|
1954
|
-
|
|
1955
|
-
p_params = (p_mu, None)
|
|
1956
|
-
|
|
1957
|
-
if q_params is not None:
|
|
1958
|
-
# At inference time, just don't centercrop the q_params even if they are odd in size.
|
|
1959
|
-
q_mu, _ = self.process_q_params(
|
|
1960
|
-
q_params, var_clip_max, allow_oddsizes=mode_pred is True
|
|
1961
|
-
)
|
|
1962
|
-
q_params = (q_mu, None)
|
|
1963
|
-
debug_qvar_max = torch.Tensor([1]).to(q_mu.device)
|
|
1964
|
-
# Sample from q(z)
|
|
1965
|
-
sampling_distrib = q_mu
|
|
1966
|
-
q_size = q_mu.shape[-1]
|
|
1967
|
-
if p_mu.shape[-1] != q_size and mode_pred is False:
|
|
1968
|
-
p_mu.centercrop_to_size(q_size)
|
|
1969
|
-
else:
|
|
1970
|
-
# Sample from p(z)
|
|
1971
|
-
sampling_distrib = p_mu
|
|
1972
|
-
|
|
1973
|
-
# Generate latent variable (typically by sampling)
|
|
1974
|
-
z = sampling_distrib
|
|
1975
|
-
|
|
1976
|
-
# Copy one sample (and distrib parameters) over the whole batch.
|
|
1977
|
-
# This is used when doing experiment from the prior - q is not used.
|
|
1978
|
-
if force_constant_output:
|
|
1979
|
-
z = z[0:1].expand_as(z).clone()
|
|
1980
|
-
p_params = (
|
|
1981
|
-
p_params[0][0:1].expand_as(p_params[0]).clone(),
|
|
1982
|
-
p_params[1][0:1].expand_as(p_params[1]).clone(),
|
|
1983
|
-
)
|
|
1984
|
-
|
|
1985
|
-
# Output of stochastic layer
|
|
1986
|
-
out = self.conv_out(z)
|
|
1987
|
-
|
|
1988
|
-
kl_dict = {}
|
|
1989
|
-
logprob_q = None
|
|
1990
|
-
|
|
1991
|
-
data = kl_dict
|
|
1992
|
-
data["z"] = z # sampled variable at this layer (batch, ch, h, w)
|
|
1993
|
-
data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size
|
|
1994
|
-
data["q_params"] = q_params # (batch, ch, h, w)
|
|
1995
|
-
data["logprob_q"] = logprob_q # (batch, )
|
|
1996
|
-
data["qvar_max"] = debug_qvar_max
|
|
1997
|
-
|
|
1998
|
-
return out, data
|
|
1371
|
+
return x, data
|