careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/layers.py (new file)
@@ -0,0 +1,1998 @@
+ """
+ Script containing the common basic blocks (nn.Module) reused by the LadderVAE architecture.
+
+ Hierarchy in the model blocks:
+
+ """
+
+ from copy import deepcopy
+ from typing import Callable, Dict, Iterable, Literal, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms.functional as F
+ from torch.distributions import kl_divergence
+ from torch.distributions.normal import Normal
+
+ from .utils import (
+     StableLogVar,
+     StableMean,
+     crop_img_tensor,
+     kl_normal_mc,
+     pad_img_tensor,
+ )
+
+
+ class ResidualBlock(nn.Module):
+     """
+     Residual block with 2 convolutional layers.
+
+     Some architectural notes:
+         - The number of input, intermediate, and output channels is the same,
+         - Padding is always 'same',
+         - The 2 convolutional layers have the same groups,
+         - No stride allowed,
+         - Kernel sizes must be odd.
+
+     The output is given by: `out = gate(f(x)) + x`.
+     The presence of the gating mechanism is optional, and f(x) has different
+     structures depending on the `block_type` argument.
+     Specifically, `block_type` is a string specifying the block's structure, with:
+         a = activation
+         b = batch norm
+         c = conv layer
+         d = dropout.
+     For example, "bacdbacd" defines a block with 2x[batchnorm, activation, conv, dropout].
+     """
+
+     default_kernel_size = (3, 3)
+
+     def __init__(
+         self,
+         channels: int,
+         nonlin: Callable,
+         kernel: Union[int, Iterable[int]] = None,
+         groups: int = 1,
+         batchnorm: bool = True,
+         block_type: str = None,
+         dropout: float = None,
+         gated: bool = None,
+         skip_padding: bool = False,
+         conv2d_bias: bool = True,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         channels: int
+             The number of input and output channels (they are the same).
+         nonlin: Callable
+             The non-linearity function used in the block (e.g., `nn.ReLU`).
+         kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         groups: int, optional
+             The number of groups to consider in the convolutions. Default is 1.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         block_type: str, optional
+             A string specifying the block structure, check the class docstring for more info.
+             Default is `None`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         gated: bool, optional
+             Whether to use a gated layer. Default is `None`.
+         skip_padding: bool, optional
+             Whether to skip padding in convolutions. Default is `False`.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in convolutions. Default is `True`.
+         """
+         super().__init__()
+
+         # Set kernel size & padding
+         if kernel is None:
+             kernel = self.default_kernel_size
+         elif isinstance(kernel, int):
+             kernel = (kernel, kernel)
+         elif len(kernel) != 2:
+             raise ValueError("kernel has to be None, int, or an iterable of length 2")
+         assert all([k % 2 == 1 for k in kernel]), "kernel sizes have to be odd"
+         kernel = list(kernel)
+         self.skip_padding = skip_padding
+         pad = [0] * len(kernel) if self.skip_padding else [k // 2 for k in kernel]
+         # print(kernel, pad)
+
+         modules = []
+         if block_type == "cabdcabd":
+             for i in range(2):
+                 conv = nn.Conv2d(
+                     channels,
+                     channels,
+                     kernel[i],
+                     padding=pad[i],
+                     groups=groups,
+                     bias=conv2d_bias,
+                 )
+                 modules.append(conv)
+                 modules.append(nonlin())
+                 if batchnorm:
+                     modules.append(nn.BatchNorm2d(channels))
+                 if dropout is not None:
+                     modules.append(nn.Dropout2d(dropout))
+         elif block_type == "bacdbac":
+             for i in range(2):
+                 if batchnorm:
+                     modules.append(nn.BatchNorm2d(channels))
+                 modules.append(nonlin())
+                 conv = nn.Conv2d(
+                     channels,
+                     channels,
+                     kernel[i],
+                     padding=pad[i],
+                     groups=groups,
+                     bias=conv2d_bias,
+                 )
+                 modules.append(conv)
+                 if dropout is not None and i == 0:
+                     modules.append(nn.Dropout2d(dropout))
+         elif block_type == "bacdbacd":
+             for i in range(2):
+                 if batchnorm:
+                     modules.append(nn.BatchNorm2d(channels))
+                 modules.append(nonlin())
+                 conv = nn.Conv2d(
+                     channels,
+                     channels,
+                     kernel[i],
+                     padding=pad[i],
+                     groups=groups,
+                     bias=conv2d_bias,
+                 )
+                 modules.append(conv)
+                 modules.append(nn.Dropout2d(dropout))
+         else:
+             raise ValueError(f"unrecognized block type '{block_type}'")
+
+         self.gated = gated
+         if gated:
+             modules.append(GateLayer2d(channels, 1, nonlin))
+
+         self.block = nn.Sequential(*modules)
+
+     def forward(self, x):
+         out = self.block(x)
+         if out.shape != x.shape:
+             return out + F.center_crop(x, out.shape[-2:])
+         else:
+             return out + x
+
+
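A minimal usage sketch for `ResidualBlock` (illustrative code, not lines from the diff; the parameter values are assumptions): with padding enabled, the block preserves both the channel count and the spatial size, so the residual sum is well-defined.

import torch
import torch.nn as nn

# Sketch with assumed values; ResidualBlock is the class defined above.
block = ResidualBlock(
    channels=32,
    nonlin=nn.LeakyReLU,
    block_type="bacdbac",
    dropout=0.1,
    gated=True,
)
x = torch.randn(4, 32, 64, 64)
assert block(x).shape == x.shape  # same channels and spatial size as the input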
+ class ResidualGatedBlock(ResidualBlock):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs, gated=True)
+
+
+ class GateLayer2d(nn.Module):
+     """
+     Double the number of channels through a convolutional layer, then use
+     half the channels as gate for the other half.
+     """
+
+     def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU):
+         super().__init__()
+         assert kernel_size % 2 == 1
+         pad = kernel_size // 2
+         self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad)
+         self.nonlin = nonlin()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x, gate = torch.chunk(x, 2, dim=1)
+         x = self.nonlin(x)  # TODO remove this?
+         gate = torch.sigmoid(gate)
+         return x * gate
+
+
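A quick sketch of the gating mechanism (illustrative, with assumed shapes): the convolution doubles the channels, and one half is passed through a sigmoid to gate the other half, so the output channel count matches the input.

import torch

gate = GateLayer2d(channels=8, kernel_size=1)
y = gate(torch.randn(2, 8, 16, 16))
# The conv outputs 16 channels; chunking and gating bring it back to 8.
assert y.shape == (2, 8, 16, 16)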
+ class ResBlockWithResampling(nn.Module):
+     """
+     Residual block that takes care of resampling (i.e. downsampling or upsampling) steps (by a factor of 2).
+     It is structured as follows:
+         1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or
+            a 1x1 convolutional layer that maps the number of channels of the input to `inner_channels`.
+         2. `ResidualBlock`
+         3. `post_conv`: a 1x1 convolutional layer that maps the number of channels to `c_out`.
+
+     Some implementation notes:
+         - Resampling is performed through a strided convolution layer at the beginning of the block.
+         - The strided convolution block has fixed kernel size of 3x3 and 1 layer of zero-padding.
+         - The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers.
+         - The number of internal channels is by default the same as the number of output channels, but
+           `min_inner_channels` can override this behaviour.
+     """
+
+     def __init__(
+         self,
+         mode: Literal["top-down", "bottom-up"],
+         c_in: int,
+         c_out: int,
+         min_inner_channels: int = None,
+         nonlin: Callable = nn.LeakyReLU,
+         resample: bool = False,
+         res_block_kernel: Union[int, Iterable[int]] = None,
+         groups: int = 1,
+         batchnorm: bool = True,
+         res_block_type: str = None,
+         dropout: float = None,
+         gated: bool = None,
+         skip_padding: bool = False,
+         conv2d_bias: bool = True,
+         # lowres_input: bool = False,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         mode: Literal["top-down", "bottom-up"]
+             The type of resampling performed in the initial strided convolution of the block.
+             If "bottom-up", downsampling by a factor of 2 is done.
+             If "top-down", upsampling by a factor of 2 is done.
+         c_in: int
+             The number of input channels.
+         c_out: int
+             The number of output channels.
+         min_inner_channels: int, optional
+             The number of channels used in the inner layer of this module.
+             Default is `None`, meaning that the number of inner channels is set to `c_out`.
+         nonlin: Callable, optional
+             The non-linearity function used in the block. Default is `nn.LeakyReLU`.
+         resample: bool, optional
+             Whether to perform resampling in the first convolutional layer.
+             If `False`, the first convolutional layer just maps the input to a tensor with
+             `inner_channels` channels through 1x1 convolution. Default is `False`.
+         res_block_kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the residual block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         groups: int, optional
+             The number of groups to consider in the convolutions. Default is 1.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         res_block_type: str, optional
+             A string specifying the structure of the residual block.
+             Check the `ResidualBlock` docstring for more information.
+             Default is `None`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         gated: bool, optional
+             Whether to use a gated layer. Default is `None`.
+         skip_padding: bool, optional
+             Whether to skip padding in convolutions. Default is `False`.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in convolutions. Default is `True`.
+         """
+         super().__init__()
+         assert mode in ["top-down", "bottom-up"]
+
+         if min_inner_channels is None:
+             min_inner_channels = 0
+         # inner_channels is the number of channels used in the inner layers
+         # of ResBlockWithResampling
+         inner_channels = max(c_out, min_inner_channels)
+
+         # Define first conv layer to change num channels and/or up/downsample
+         if resample:
+             if mode == "bottom-up":  # downsample
+                 self.pre_conv = nn.Conv2d(
+                     in_channels=c_in,
+                     out_channels=inner_channels,
+                     kernel_size=3,
+                     padding=1,
+                     stride=2,
+                     groups=groups,
+                     bias=conv2d_bias,
+                 )
+             elif mode == "top-down":  # upsample
+                 self.pre_conv = nn.ConvTranspose2d(
+                     in_channels=c_in,
+                     kernel_size=3,
+                     out_channels=inner_channels,
+                     padding=1,
+                     stride=2,
+                     groups=groups,
+                     output_padding=1,
+                     bias=conv2d_bias,
+                 )
+         elif c_in != inner_channels:
+             self.pre_conv = nn.Conv2d(
+                 c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
+             )
+         else:
+             self.pre_conv = None
+
+         # Residual block
+         self.res = ResidualBlock(
+             channels=inner_channels,
+             nonlin=nonlin,
+             kernel=res_block_kernel,
+             groups=groups,
+             batchnorm=batchnorm,
+             dropout=dropout,
+             gated=gated,
+             block_type=res_block_type,
+             skip_padding=skip_padding,
+             conv2d_bias=conv2d_bias,
+         )
+
+         # Define last conv layer to get correct num output channels
+         if inner_channels != c_out:
+             self.post_conv = nn.Conv2d(
+                 inner_channels, c_out, 1, groups=groups, bias=conv2d_bias
+             )
+         else:
+             self.post_conv = None
+
+     def forward(self, x):
+         if self.pre_conv is not None:
+             x = self.pre_conv(x)
+
+         x = self.res(x)
+
+         if self.post_conv is not None:
+             x = self.post_conv(x)
+         return x
+
+
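A minimal sketch of the two resampling modes (illustrative, assuming the "bacdbac" residual layout): "bottom-up" halves the spatial dimensions with a stride-2 convolution, while "top-down" doubles them with a transposed convolution.

import torch

down = ResBlockWithResampling(
    "bottom-up", c_in=16, c_out=32, resample=True, res_block_type="bacdbac"
)
up = ResBlockWithResampling(
    "top-down", c_in=32, c_out=16, resample=True, res_block_type="bacdbac"
)
x = torch.randn(1, 16, 64, 64)
y = down(x)  # stride-2 convolution: 64x64 -> 32x32
assert y.shape == (1, 32, 32, 32)
assert up(y).shape == (1, 16, 64, 64)  # transposed convolution: 32x32 -> 64x64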
+ class TopDownDeterministicResBlock(ResBlockWithResampling):
+
+     def __init__(self, *args, upsample: bool = False, **kwargs):
+         kwargs["resample"] = upsample
+         super().__init__("top-down", *args, **kwargs)
+
+
+ class BottomUpDeterministicResBlock(ResBlockWithResampling):
+
+     def __init__(self, *args, downsample: bool = False, **kwargs):
+         kwargs["resample"] = downsample
+         super().__init__("bottom-up", *args, **kwargs)
+
+
+ class BottomUpLayer(nn.Module):
+     """
+     Bottom-up deterministic layer.
+     It consists of one or a stack of `BottomUpDeterministicResBlock`'s.
+     The outputs are the so-called `bu_values` that are later used in the Decoder to update the
+     generative distributions.
+
+     NOTE: When Lateral Contextualization is enabled (i.e., `enable_multiscale=True`),
+     the low-res lateral input is first fed through a BottomUpDeterministicBlock (BUDB)
+     (without downsampling), and then merged with the latent tensor produced by the primary flow
+     of the `BottomUpLayer` through the `MergeLowRes` layer. It is worth remarking that
+     the BUDB that takes care of encoding the low-res input can either be shared with the
+     primary flow (and in that case it is the "same_size" BUDB (or stack of BUDBs) -> see `self.net`),
+     or be a deep copy of the primary flow's BUDB.
+     This behaviour is controlled by the `lowres_separate_branch` parameter.
+     """
+
+     def __init__(
+         self,
+         n_res_blocks: int,
+         n_filters: int,
+         downsampling_steps: int = 0,
+         nonlin: Callable = None,
+         batchnorm: bool = True,
+         dropout: float = None,
+         res_block_type: str = None,
+         res_block_kernel: int = None,
+         res_block_skip_padding: bool = False,
+         gated: bool = None,
+         enable_multiscale: bool = False,
+         multiscale_lowres_size_factor: int = None,
+         lowres_separate_branch: bool = False,
+         multiscale_retain_spatial_dims: bool = False,
+         decoder_retain_spatial_dims: bool = False,
+         output_expected_shape: Iterable[int] = None,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         n_res_blocks: int
+             Number of `BottomUpDeterministicResBlock` modules stacked in this layer.
+         n_filters: int
+             Number of channels present throughout the layers of this block.
+         downsampling_steps: int, optional
+             Number of downsampling steps that have to be done in this layer (typically 1).
+             Default is 0.
+         nonlin: Callable, optional
+             The non-linearity function used in the block. Default is `None`.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         res_block_type: str, optional
+             A string specifying the structure of the residual block.
+             Check the `ResidualBlock` docstring for more information.
+             Default is `None`.
+         res_block_kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the residual block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         res_block_skip_padding: bool, optional
+             Whether to skip padding in convolutions in the residual block. Default is `False`.
+         gated: bool, optional
+             Whether to use a gated layer. Default is `None`.
+         enable_multiscale: bool, optional
+             Whether to enable multiscale (Lateral Contextualization) or not. Default is `False`.
+         multiscale_lowres_size_factor: int, optional
+             A factor that expresses the relative size of the primary flow tensor with respect to the
+             lower-resolution lateral input tensor. Default is `None`.
+         lowres_separate_branch: bool, optional
+             Whether the residual block(s) encoding the low-res input should be shared (`False`) or
+             not (`True`) with the primary flow "same-size" residual block(s). Default is `False`.
+         multiscale_retain_spatial_dims: bool, optional
+             Whether to pad the latent tensor resulting from the bottom-up layer's primary flow
+             to match the size of the low-res input. Default is `False`.
+         decoder_retain_spatial_dims: bool, optional
+             Default is `False`.
+         output_expected_shape: Iterable[int], optional
+             The expected shape of the layer output (only used if `enable_multiscale == True`).
+             Default is `None`.
+         """
+         super().__init__()
+
+         # Define attributes for Lateral Contextualization
+         self.enable_multiscale = enable_multiscale
+         self.lowres_separate_branch = lowres_separate_branch
+         self.multiscale_retain_spatial_dims = multiscale_retain_spatial_dims
+         self.multiscale_lowres_size_factor = multiscale_lowres_size_factor
+         self.decoder_retain_spatial_dims = decoder_retain_spatial_dims
+         self.output_expected_shape = output_expected_shape
+         assert self.output_expected_shape is None or self.enable_multiscale is True
+
+         bu_blocks_downsized = []
+         bu_blocks_samesize = []
+         for _ in range(n_res_blocks):
+             do_resample = False
+             if downsampling_steps > 0:
+                 do_resample = True
+                 downsampling_steps -= 1
+             block = BottomUpDeterministicResBlock(
+                 c_in=n_filters,
+                 c_out=n_filters,
+                 nonlin=nonlin,
+                 downsample=do_resample,
+                 batchnorm=batchnorm,
+                 dropout=dropout,
+                 res_block_type=res_block_type,
+                 res_block_kernel=res_block_kernel,
+                 skip_padding=res_block_skip_padding,
+                 gated=gated,
+             )
+             if do_resample:
+                 bu_blocks_downsized.append(block)
+             else:
+                 bu_blocks_samesize.append(block)
+
+         self.net_downsized = nn.Sequential(*bu_blocks_downsized)
+         self.net = nn.Sequential(*bu_blocks_samesize)
+
+         # Using the same net for the low resolution (and larger sized image)
+         self.lowres_net = self.lowres_merge = None
+         if self.enable_multiscale:
+             self._init_multiscale(
+                 n_filters=n_filters,
+                 nonlin=nonlin,
+                 batchnorm=batchnorm,
+                 dropout=dropout,
+                 res_block_type=res_block_type,
+             )
+
+         # msg = f'[{self.__class__.__name__}] McEnabled:{int(enable_multiscale)} '
+         # if enable_multiscale:
+         #     msg += f'McParallelBeam:{int(multiscale_retain_spatial_dims)} McFactor{multiscale_lowres_size_factor}'
+         # print(msg)
+
+     def _init_multiscale(
+         self,
+         nonlin: Callable = None,
+         n_filters: int = None,
+         batchnorm: bool = None,
+         dropout: float = None,
+         res_block_type: str = None,
+     ) -> None:
+         """
+         This method defines the modules responsible for merging compressed lateral inputs with the outputs
+         of the primary flow at different hierarchical levels in the multiresolution approach (LC).
+
+         Specifically, the method initializes `lowres_net`, which is a stack of `BottomUpDeterministicBlock`'s
+         (w/out downsampling) that takes care of additionally processing the low-res input, and `lowres_merge`,
+         which is the module responsible for merging the compressed lateral input into the main flow.
+
+         NOTE: The merge modality is set by default to "residual", meaning that the merge layer
+         performs concatenation on dim=1, followed by a 1x1 convolution and a Residual Gated block.
+
+         Parameters
+         ----------
+         nonlin: Callable, optional
+             The non-linearity function used in the block. Default is `None`.
+         n_filters: int
+             Number of channels present throughout the layers of this block.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         res_block_type: str, optional
+             A string specifying the structure of the residual block.
+             Check the `ResidualBlock` docstring for more information.
+             Default is `None`.
+         """
+         self.lowres_net = self.net
+         if self.lowres_separate_branch:
+             self.lowres_net = deepcopy(self.net)
+
+         self.lowres_merge = MergeLowRes(
+             channels=n_filters,
+             merge_type="residual",
+             nonlin=nonlin,
+             batchnorm=batchnorm,
+             dropout=dropout,
+             res_block_type=res_block_type,
+             multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
+             multiscale_lowres_size_factor=self.multiscale_lowres_size_factor,
+         )
+
+     def forward(
+         self, x: torch.Tensor, lowres_x: torch.Tensor = None
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input of the `BottomUpLayer`, i.e., the input image or the output of the
+             previous layer.
+         lowres_x: torch.Tensor, optional
+             The low-res input used for Lateral Contextualization (LC). Default is `None`.
+         """
+         # The input is fed through the residual downsampling block(s)
+         primary_flow = self.net_downsized(x)
+         # The downsampling output is fed through additional residual block(s)
+         primary_flow = self.net(primary_flow)
+
+         # If LC is not used, simply return the output of the primary flow
+         if self.enable_multiscale is False:
+             assert lowres_x is None
+             return primary_flow, primary_flow
+
+         if lowres_x is not None:
+             # First encode the low-res lateral input
+             lowres_flow = self.lowres_net(lowres_x)
+             # Then pass the result through the MergeLowRes layer
+             merged = self.lowres_merge(primary_flow, lowres_flow)
+         else:
+             merged = primary_flow
+
+         if (
+             self.multiscale_retain_spatial_dims is False
+             or self.decoder_retain_spatial_dims is True
+         ):
+             return merged, merged
+
+         if self.output_expected_shape is not None:
+             expected_shape = self.output_expected_shape
+         else:
+             fac = self.multiscale_lowres_size_factor
+             expected_shape = (merged.shape[-2] // fac, merged.shape[-1] // fac)
+             assert merged.shape[-2:] != expected_shape
+
+         # Crop the resulting tensor so that it matches with the Decoder
+         value_to_use_in_topdown = crop_img_tensor(merged, expected_shape)
+         return merged, value_to_use_in_topdown
+
+
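A minimal sketch of the primary flow without Lateral Contextualization (illustrative values; one downsampling step): with multiscale disabled, the two returned tensors coincide.

import torch
import torch.nn as nn

layer = BottomUpLayer(
    n_res_blocks=2,
    n_filters=32,
    downsampling_steps=1,
    nonlin=nn.LeakyReLU,
    res_block_type="bacdbac",
)
x = torch.randn(1, 32, 64, 64)
out, bu_value = layer(x)  # multiscale disabled: both outputs are the same tensor
assert out.shape == bu_value.shape == (1, 32, 32, 32)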
+ class MergeLayer(nn.Module):
+     """
+     This layer merges two or more 4D input tensors by concatenating along dim=1 and passes the result through:
+     a) a convolutional 1x1 layer (`merge_type == "linear"`), or
+     b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
+     c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
+     """
+
+     def __init__(
+         self,
+         merge_type: Literal["linear", "residual", "residual_ungated"],
+         channels: Union[int, Iterable[int]],
+         nonlin: Callable = nn.LeakyReLU,
+         batchnorm: bool = True,
+         dropout: float = None,
+         res_block_type: str = None,
+         res_block_kernel: int = None,
+         res_block_skip_padding: bool = False,
+         conv2d_bias: bool = True,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         merge_type: Literal["linear", "residual", "residual_ungated"]
+             The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
+             Check the class docstring for more information about the behaviour of the different merge modalities.
+         channels: Union[int, Iterable[int]]
+             The number of channels used in the convolutional blocks of this layer.
+             If it is an `int`:
+                 - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
+                 - (Optional) ResBlock: in_channels=channels, out_channels=channels
+             If it is an Iterable (must have `len(channels)==3`):
+                 - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
+                 - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
+         nonlin: Callable, optional
+             The non-linearity function used in the block. Default is `nn.LeakyReLU`.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         res_block_type: str, optional
+             A string specifying the structure of the residual block.
+             Check the `ResidualBlock` docstring for more information.
+             Default is `None`.
+         res_block_kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the residual block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         res_block_skip_padding: bool, optional
+             Whether to skip padding in convolutions in the residual block. Default is `False`.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in convolutions. Default is `True`.
+         """
+         super().__init__()
+         try:
+             iter(channels)
+         except TypeError:  # it is not iterable
+             channels = [channels] * 3
+         else:  # it is iterable
+             if len(channels) == 1:
+                 channels = [channels[0]] * 3
+
+         # assert len(channels) == 3
+
+         if merge_type == "linear":
+             self.layer = nn.Conv2d(
+                 sum(channels[:-1]), channels[-1], 1, bias=conv2d_bias
+             )
+         elif merge_type == "residual":
+             self.layer = nn.Sequential(
+                 nn.Conv2d(
+                     sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
+                 ),
+                 ResidualGatedBlock(
+                     channels[-1],
+                     nonlin,
+                     batchnorm=batchnorm,
+                     dropout=dropout,
+                     block_type=res_block_type,
+                     kernel=res_block_kernel,
+                     conv2d_bias=conv2d_bias,
+                     skip_padding=res_block_skip_padding,
+                 ),
+             )
+         elif merge_type == "residual_ungated":
+             self.layer = nn.Sequential(
+                 nn.Conv2d(
+                     sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
+                 ),
+                 ResidualBlock(
+                     channels[-1],
+                     nonlin,
+                     batchnorm=batchnorm,
+                     dropout=dropout,
+                     block_type=res_block_type,
+                     kernel=res_block_kernel,
+                     conv2d_bias=conv2d_bias,
+                     skip_padding=res_block_skip_padding,
+                 ),
+             )
+
+     def forward(self, *args) -> torch.Tensor:
+         # Concatenate the input tensors along dim=1
+         x = torch.cat(args, dim=1)
+
+         # Pass the concatenated tensor through the conv layer
+         x = self.layer(x)
+
+         return x
+
+
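A minimal sketch of the "residual" merge (illustrative, assuming two 32-channel inputs): concatenation yields 64 channels, the 1x1 convolution maps back to 32, and a gated residual block refines the merged tensor.

import torch

merge = MergeLayer(merge_type="residual", channels=32, res_block_type="bacdbac")
a = torch.randn(1, 32, 16, 16)
b = torch.randn(1, 32, 16, 16)
assert merge(a, b).shape == (1, 32, 16, 16)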
+ class MergeLowRes(MergeLayer):
+     """
+     Child class of `MergeLayer`, specifically designed to merge the low-resolution patches
+     that are used in the Lateral Contextualization approach.
+     """
+
+     def __init__(self, *args, **kwargs):
+         self.retain_spatial_dims = kwargs.pop("multiscale_retain_spatial_dims")
+         self.multiscale_lowres_size_factor = kwargs.pop("multiscale_lowres_size_factor")
+         super().__init__(*args, **kwargs)
+
+     def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
+         """
+         Parameters
+         ----------
+         latent: torch.Tensor
+             The output latent tensor from the previous layer in the LVAE hierarchy.
+         lowres: torch.Tensor
+             The low-res patch image to be merged to increase the context.
+         """
+         if self.retain_spatial_dims:
+             # Pad latent tensor to match lowres tensor's shape
+             latent = pad_img_tensor(latent, lowres.shape[2:])
+         else:
+             # Crop lowres tensor to match latent tensor's shape
+             lh, lw = lowres.shape[-2:]
+             h = lh // self.multiscale_lowres_size_factor
+             w = lw // self.multiscale_lowres_size_factor
+             h_pad = (lh - h) // 2
+             w_pad = (lw - w) // 2
+             lowres = lowres[:, :, h_pad:-h_pad, w_pad:-w_pad]
+
+         return super().forward(latent, lowres)
+
+
+ class SkipConnectionMerger(MergeLayer):
+     """
+     A specialized `MergeLayer` module, designed to handle skip connections in the model.
+     """
+
+     def __init__(
+         self,
+         nonlin: Callable,
+         channels: Union[int, Iterable[int]],
+         batchnorm: bool,
+         dropout: float,
+         res_block_type: str,
+         merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
+         conv2d_bias: bool = True,
+         res_block_kernel: int = None,
+         res_block_skip_padding: bool = False,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         nonlin: Callable
+             The non-linearity function used in the block.
+         channels: Union[int, Iterable[int]]
+             The number of channels used in the convolutional blocks of this layer.
+             If it is an `int`:
+                 - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
+                 - (Optional) ResBlock: in_channels=channels, out_channels=channels
+             If it is an Iterable (must have `len(channels)==3`):
+                 - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
+                 - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
+         batchnorm: bool
+             Whether to use batchnorm layers.
+         dropout: float
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+         res_block_type: str
+             A string specifying the structure of the residual block.
+             Check the `ResidualBlock` docstring for more information.
+         merge_type: Literal["linear", "residual", "residual_ungated"], optional
+             The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
+             Check the class docstring for more information about the behaviour of the different merge modalities.
+             Default is "residual".
+         conv2d_bias: bool, optional
+             Whether to use a bias term in convolutions. Default is `True`.
+         res_block_kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the residual block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         res_block_skip_padding: bool, optional
+             Whether to skip padding in convolutions in the residual block. Default is `False`.
+         """
+         super().__init__(
+             channels=channels,
+             nonlin=nonlin,
+             merge_type=merge_type,
+             batchnorm=batchnorm,
+             dropout=dropout,
+             res_block_type=res_block_type,
+             res_block_kernel=res_block_kernel,
+             conv2d_bias=conv2d_bias,
+             res_block_skip_padding=res_block_skip_padding,
+         )
+
+
+ class TopDownLayer(nn.Module):
+     """
+     Top-down inference layer.
+     It includes:
+         - Stochastic sampling,
+         - Computation of KL divergence,
+         - A small deterministic ResNet that performs upsampling.
+
+     NOTE 1:
+     The algorithm for generative inference approximately works as follows:
+         - p_params = output of top-down layer above
+         - bu = inferred bottom-up value at this layer
+         - q_params = merge(bu, p_params)
+         - z = stochastic_layer(q_params)
+         - (optional) get and merge skip connection from prev top-down layer
+         - top-down deterministic ResNet
+
+     NOTE 2:
+     The Top-Down layer can work in two modes: inference and prediction/generative.
+     Depending on the particular mode, it follows distinct behaviours:
+         - In inference mode, parameters of q(z_i|z_{i+1}) are obtained from the inference path,
+           by merging outcomes of bottom-up and top-down passes. The exception is the top layer,
+           in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer.
+         - On the contrary, in prediction/generative mode, parameters of q(z_i|z_{i+1}) can be obtained
+           once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is
+           possible to directly sample from the prior p(z_i|z_{i+1}) (UNCONDITIONAL GENERATION).
+
+     NOTE 3:
+     When doing unconditional generation, bu_value is not available. Hence the
+     merge layer is not used, and z is sampled directly from p_params.
+
+     NOTE 4:
+     If this is the top layer, at inference time, the uppermost bottom-up value
+     is used directly as q_params, and p_params are defined in this layer
+     (while they are usually taken from the previous layer), and can be learned.
+     """
+
+     def __init__(
+         self,
+         z_dim: int,
+         n_res_blocks: int,
+         n_filters: int,
+         is_top_layer: bool = False,
+         downsampling_steps: int = None,
+         nonlin: Callable = None,
+         merge_type: Literal["linear", "residual", "residual_ungated"] = None,
+         batchnorm: bool = True,
+         dropout: float = None,
+         stochastic_skip: bool = False,
+         res_block_type: str = None,
+         res_block_kernel: int = None,
+         res_block_skip_padding: bool = None,
+         groups: int = 1,
+         gated: bool = None,
+         learn_top_prior: bool = False,
+         top_prior_param_shape: Iterable[int] = None,
+         analytical_kl: bool = False,
+         bottomup_no_padding_mode: bool = False,
+         topdown_no_padding_mode: bool = False,
+         retain_spatial_dims: bool = False,
+         restricted_kl: bool = False,
+         vanilla_latent_hw: Iterable[int] = None,
+         non_stochastic_version: bool = False,
+         input_image_shape: Union[None, Tuple[int, int]] = None,
+         normalize_latent_factor: float = 1.0,
+         conv2d_bias: bool = True,
+         stochastic_use_naive_exponential: bool = False,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         z_dim: int
+             The size of the latent space.
+         n_res_blocks: int
+             The number of TopDownDeterministicResBlock blocks.
+         n_filters: int
+             The number of channels present throughout the layers of this block.
+         is_top_layer: bool, optional
+             Whether the current layer is at the top of the Decoder hierarchy. Default is `False`.
+         downsampling_steps: int, optional
+             The number of downsampling steps that have to be done in this layer (typically 1).
+             Default is `None`.
+         nonlin: Callable, optional
+             The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
+         merge_type: Literal["linear", "residual", "residual_ungated"], optional
+             The type of merge done in the layer. It can be chosen between "linear", "residual",
+             and "residual_ungated". Check the `MergeLayer` class docstring for more information
+             about the behaviour of the different merging modalities. Default is `None`.
+         batchnorm: bool, optional
+             Whether to use batchnorm layers. Default is `True`.
+         dropout: float, optional
+             The dropout probability in dropout layers. If `None`, dropout is not used.
+             Default is `None`.
+         stochastic_skip: bool, optional
+             Whether to use skip connections between the previous top-down layer's output and this layer's stochastic output.
+             A stochastic skip connection gives the previous layer's output a way to directly reach this hierarchical
+             level, hence facilitating the gradient flow during backpropagation. Default is `False`.
+         res_block_type: str, optional
+             A string specifying the structure of the residual block.
+             Check `ResidualBlock` documentation for more information.
+             Default is `None`.
+         res_block_kernel: Union[int, Iterable[int]], optional
+             The kernel size used in the convolutions of the residual block.
+             It can be either a single integer or a pair of integers defining the square kernel.
+             Default is `None`.
+         res_block_skip_padding: bool, optional
+             Whether to skip padding in convolutions in the residual block. Default is `None`.
+         groups: int, optional
+             The number of groups to consider in the convolutions. Default is 1.
+         gated: bool, optional
+             Whether to use a gated layer in `ResidualBlock`. Default is `None`.
+         learn_top_prior: bool, optional
+             Whether to set the top prior as learnable.
+             If this is set to `False`, in the top-most layer the prior will be N(0,1).
+             Otherwise, we will still have a normal distribution whose parameters will be learnt.
+             Default is `False`.
+         top_prior_param_shape: Iterable[int], optional
+             The size of the tensor which expresses the mean and the variance
+             of the prior for the topmost layer. Default is `None`.
+         analytical_kl: bool, optional
+             If True, KL divergence is calculated according to the analytical formula.
+             Otherwise, an MC approximation using sampled latents is calculated.
+             Default is `False`.
+         bottomup_no_padding_mode: bool, optional
+             Whether padding is used in the different layers of the bottom-up pass.
+             It is meaningful to know this in advance in order to assess whether, before
+             merging the `bu_values` and `p_params` tensors, any alignment is needed.
+             Default is `False`.
+         topdown_no_padding_mode: bool, optional
+             Whether padding is used in the different layers of the top-down pass.
+             It is meaningful to know this in advance in order to assess whether, before
+             merging the `bu_values` and `p_params` tensors, any alignment is needed.
+             The same information is also needed in handling the skip connections between
+             top-down layers. Default is `False`.
+         retain_spatial_dims: bool, optional
+             If `True`, the size of the Encoder's latent space is kept to `input_image_shape` within the topdown layer.
+             This implies that the output spatial size equals the input spatial size.
+             To achieve this, we center-crop the intermediate representation.
+             Default is `False`.
+         restricted_kl: bool, optional
+             Whether to compute the restricted version of the KL divergence.
+             See the `NormalStochasticBlock2d` module for more information about its computation.
+             Default is `False`.
+         vanilla_latent_hw: Iterable[int], optional
+             The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
+             Default is `None`.
+         non_stochastic_version: bool, optional
+             Whether to replace the stochastic layer that samples a latent variable from the latent distribution with
+             a non-stochastic layer that simply draws a sample as the mode of the latent distribution.
+             Default is `False`.
+         input_image_shape: Tuple[int, int], optional
+             The shape of the input image tensor.
+             When `retain_spatial_dims` is set to `True`, this is used to ensure that this layer's
+             output has the same shape as the input. Default is `None`.
+         normalize_latent_factor: float, optional
+             A factor used to normalize the latent tensors `q_params`.
+             Specifically, normalization is done by dividing the latent tensor by this factor.
+             Default is 1.0.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in the convolutional blocks of this layer.
+             Default is `True`.
+         stochastic_use_naive_exponential: bool, optional
+             If `False`, in the NormalStochasticBlock2d exponentials are computed according
+             to the alternative definition provided by the `StableExponential` class.
+             This should improve numerical stability in the training process.
+             Default is `False`.
+         """
+         super().__init__()
+
+         self.is_top_layer = is_top_layer
+         self.z_dim = z_dim
+         self.stochastic_skip = stochastic_skip
+         self.learn_top_prior = learn_top_prior
+         self.analytical_kl = analytical_kl
+         self.bottomup_no_padding_mode = bottomup_no_padding_mode
+         self.topdown_no_padding_mode = topdown_no_padding_mode
+         self.retain_spatial_dims = retain_spatial_dims
+         self.latent_shape = input_image_shape if self.retain_spatial_dims else None
+         self.non_stochastic_version = non_stochastic_version
+         self.normalize_latent_factor = normalize_latent_factor
+         self._vanilla_latent_hw = vanilla_latent_hw
+
+         # Define top layer prior parameters, possibly learnable
+         if is_top_layer:
+             self.top_prior_params = nn.Parameter(
+                 torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
+             )
+
+         # Downsampling steps left to do in this layer
+         dws_left = downsampling_steps
+
+         # Define deterministic top-down block, which is a sequence of deterministic
+         # residual blocks with (optional) upsampling.
+         block_list = []
+         for _ in range(n_res_blocks):
+             do_resample = False
+             if dws_left > 0:
+                 do_resample = True
+                 dws_left -= 1
+             block_list.append(
+                 TopDownDeterministicResBlock(
+                     c_in=n_filters,
+                     c_out=n_filters,
+                     nonlin=nonlin,
+                     upsample=do_resample,
+                     batchnorm=batchnorm,
+                     dropout=dropout,
+                     res_block_type=res_block_type,
+                     res_block_kernel=res_block_kernel,
+                     skip_padding=res_block_skip_padding,
+                     gated=gated,
+                     conv2d_bias=conv2d_bias,
+                     groups=groups,
+                 )
+             )
+         self.deterministic_block = nn.Sequential(*block_list)
+
+         # Define stochastic block with 2D convolutions
+         if self.non_stochastic_version:
+             self.stochastic = NonStochasticBlock2d(
+                 c_in=n_filters,
+                 c_vars=z_dim,
+                 c_out=n_filters,
+                 transform_p_params=(not is_top_layer),
+                 groups=groups,
+                 conv2d_bias=conv2d_bias,
+             )
+         else:
+             self.stochastic = NormalStochasticBlock2d(
+                 c_in=n_filters,
+                 c_vars=z_dim,
+                 c_out=n_filters,
+                 transform_p_params=(not is_top_layer),
+                 vanilla_latent_hw=vanilla_latent_hw,
+                 restricted_kl=restricted_kl,
+                 use_naive_exponential=stochastic_use_naive_exponential,
+             )
+
+         if not is_top_layer:
+             # Merge layer: it combines bottom-up inference and top-down
+             # generative outcomes to give posterior parameters
+             self.merge = MergeLayer(
+                 channels=n_filters,
+                 merge_type=merge_type,
+                 nonlin=nonlin,
+                 batchnorm=batchnorm,
+                 dropout=dropout,
+                 res_block_type=res_block_type,
+                 res_block_kernel=res_block_kernel,
+                 conv2d_bias=conv2d_bias,
+             )
+
+             # Skip connection that goes around the stochastic top-down layer
+             if stochastic_skip:
+                 self.skip_connection_merger = SkipConnectionMerger(
+                     channels=n_filters,
+                     nonlin=nonlin,
+                     batchnorm=batchnorm,
+                     dropout=dropout,
+                     res_block_type=res_block_type,
+                     merge_type=merge_type,
+                     conv2d_bias=conv2d_bias,
+                     res_block_kernel=res_block_kernel,
+                     res_block_skip_padding=res_block_skip_padding,
+                 )
+
+         # print(f'[{self.__class__.__name__}] normalize_latent_factor:{self.normalize_latent_factor}')
+
+     def sample_from_q(
+         self,
+         input_: torch.Tensor,
+         bu_value: torch.Tensor,
+         var_clip_max: float = None,
+         mask: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """
+         This method computes the latent inference distribution q(z_i|z_{i+1}) and samples a latent tensor from it.
+
+         Parameters
+         ----------
+         input_: torch.Tensor
+             The input tensor to the layer, which is the output of the top-down layer above.
+         bu_value: torch.Tensor
+             The tensor defining the parameters mu_q and sigma_q computed during the bottom-up deterministic pass
+             at the corresponding hierarchical layer.
+         var_clip_max: float, optional
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped. Default is `None`.
+         mask: Union[None, torch.Tensor], optional
+             A tensor that is used to mask the sampled latent tensor. Default is `None`.
+         """
+         if self.is_top_layer:  # In top layer, we don't merge bu_value with p_params
+             q_params = bu_value
+         else:
+             # NOTE: Here the assumption is that the vampprior is only applied on the top layer.
+             n_img_prior = None
+             p_params = self.get_p_params(input_, n_img_prior)
+             q_params = self.merge(bu_value, p_params)
+
+         sample = self.stochastic.sample_from_q(q_params, var_clip_max)
+
+         if mask:
+             return sample[mask]
+
+         return sample
+
+     def get_p_params(
+         self,
+         input_: torch.Tensor,
+         n_img_prior: int,
+     ) -> torch.Tensor:
+         """
+         This method returns the parameters of the prior distribution p(z_i|z_{i+1}) for the latent tensor
+         depending on the hierarchical level of the layer and other specific conditions.
+
+         Parameters
+         ----------
+         input_: torch.Tensor
+             The input tensor to the layer, which is the output of the top-down layer above.
+         n_img_prior: int
+             The number of images to be generated from the unconditional prior distribution p(z_L).
+         """
+         p_params = None
+
+         # If top layer, define p_params as the ones of the prior p(z_L)
+         if self.is_top_layer:
+             p_params = self.top_prior_params
+
+             # Sample specific number of images by expanding the prior
+             if n_img_prior is not None:
+                 p_params = p_params.expand(n_img_prior, -1, -1, -1)
+
+         # Else the input from the layer above is p_params itself
+         else:
+             p_params = input_
+
+         return p_params
+
+     def align_pparams_buvalue(
+         self, p_params: torch.Tensor, bu_value: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         In case padding is not used in the encoder, in the decoder, or in both, we could have a shape mismatch
+         in the spatial dimensions (usually, dim=2 & dim=3).
+         This method performs a center crop to ensure that both tensors remain aligned.
+
+         Parameters
+         ----------
+         p_params: torch.Tensor
+             The tensor defining the parameters mu_p and sigma_p for the latent distribution p(z_i|z_{i+1}).
+         bu_value: torch.Tensor
+             The tensor defining the parameters mu_q and sigma_q computed during the bottom-up deterministic pass
+             at the corresponding hierarchical layer.
+         """
+         if bu_value.shape[-2:] != p_params.shape[-2:]:
+             assert self.bottomup_no_padding_mode is True
+             if self.topdown_no_padding_mode is False:
+                 assert bu_value.shape[-1] > p_params.shape[-1]
+                 bu_value = F.center_crop(bu_value, p_params.shape[-2:])
+             else:
+                 if bu_value.shape[-1] > p_params.shape[-1]:
+                     bu_value = F.center_crop(bu_value, p_params.shape[-2:])
+                 else:
+                     p_params = F.center_crop(p_params, bu_value.shape[-2:])
+         return p_params, bu_value
+
1185
+ def forward(
1186
+ self,
1187
+ input_: torch.Tensor = None,
1188
+ skip_connection_input: torch.Tensor = None,
1189
+ inference_mode: bool = False,
1190
+ bu_value: torch.Tensor = None,
1191
+ n_img_prior: int = None,
1192
+ forced_latent: torch.Tensor = None,
1193
+ use_mode: bool = False,
+         force_constant_output: bool = False,
+         mode_pred: bool = False,
+         use_uncond_mode: bool = False,
+         var_clip_max: float = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Forward pass of the top-down layer.
+
+         Parameters
+         ----------
+         input_: torch.Tensor, optional
+             The input tensor to the layer, which is the output of the top-down layer above.
+             Default is `None`.
+         skip_connection_input: torch.Tensor, optional
+             The tensor brought by the skip connection between the current and the previous
+             top-down layer. Default is `None`.
+         inference_mode: bool, optional
+             Whether the layer is in inference mode. See NOTE 2 in the class description
+             for more info. Default is `False`.
+         bu_value: torch.Tensor, optional
+             The tensor defining the parameters mu_q and sigma_q computed during the
+             bottom-up deterministic pass at the corresponding hierarchical layer.
+             Default is `None`.
+         n_img_prior: int, optional
+             The number of images to be generated from the unconditional prior distribution p(z_L).
+             Default is `None`.
+         forced_latent: torch.Tensor, optional
+             A pre-defined latent tensor. If it is not `None`, then it is used as the actual
+             latent tensor and, hence, sampling does not happen. Default is `None`.
+         use_mode: bool, optional
+             Whether the latent tensor should be set as the latent distribution mode.
+             In the case of a Gaussian, the mode coincides with the mean of the distribution.
+             Default is `False`.
+         force_constant_output: bool, optional
+             Whether to copy the first sample (and related distribution parameters) over
+             the whole batch. This is used when running experiments from the prior, in
+             which q is not used. Default is `False`.
+         mode_pred: bool, optional
+             Whether the model is in prediction mode. Default is `False`.
+         use_uncond_mode: bool, optional
+             Whether to use the unconditional distribution p(z) to sample latents in
+             prediction mode. Default is `False`.
+         var_clip_max: float, optional
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped. Default is `None`.
+         """
+         # Check consistency of arguments
+         inputs_none = input_ is None and skip_connection_input is None
+         if self.is_top_layer and not inputs_none:
+             raise ValueError("In top layer, inputs should be None")
+
+         p_params = self.get_p_params(input_, n_img_prior)
+
+         # Get the parameters for the latent distribution to sample from
+         if inference_mode:
+             if self.is_top_layer:
+                 q_params = bu_value
+                 if mode_pred is False:
+                     p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
+             else:
+                 if use_uncond_mode:
+                     q_params = p_params
+                 else:
+                     p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
+                     q_params = self.merge(bu_value, p_params)
+         # In generative mode, q is not used
+         else:
+             q_params = None
+
+         # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
+         # depending on the mode (hence, in practice, by checking whether q_params is None).
+
+         # Normalization of latent space parameters:
+         # this is done purely for stability (see "Very Deep VAEs Generalize
+         # Autoregressive Models").
+         # NOTE: the extra check guards against q_params being None in generative mode.
+         if self.normalize_latent_factor and q_params is not None:
+             q_params = q_params / self.normalize_latent_factor
+
+         # Sample (and process) a latent tensor in the stochastic layer
+         x, data_stoch = self.stochastic(
+             p_params=p_params,
+             q_params=q_params,
+             forced_latent=forced_latent,
+             use_mode=use_mode,
+             force_constant_output=force_constant_output,
+             analytical_kl=self.analytical_kl,
+             mode_pred=mode_pred,
+             use_uncond_mode=use_uncond_mode,
+             var_clip_max=var_clip_max,
+         )
+
+         # Merge skip connection from previous layer
+         if self.stochastic_skip and not self.is_top_layer:
+             if self.topdown_no_padding_mode is True:
+                 # If no padding is done in the current top-down pass, there may be a shape
+                 # mismatch between the current tensor and the skip connection input (e.g.,
+                 # the previous TopDownLayer output was 64x64, while the current tensor is
+                 # 60x60 due to the lack of padding). To avoid the mismatch, we center-crop
+                 # the skip connection input.
+                 skip_connection_input = F.center_crop(
+                     skip_connection_input, x.shape[-2:]
+                 )
+
+             x = self.skip_connection_merger(x, skip_connection_input)
+
+         # Save the activation before the residual block: it can be the skip connection
+         # input in the next layer
+         x_pre_residual = x
+
+         if self.retain_spatial_dims:
+             # When padding is also disabled in the top-down pass, we need to spare some
+             # boundary pixels that would otherwise be consumed.
+             extra_len = (self.topdown_no_padding_mode is True) * 3
+
+             # x should have the same size as config.data.image_size, so we center-crop
+             # by a factor of 2 at this point.
+             # assert x.shape[-1] >= self.latent_shape[-1] // 2 + extra_len
+             # We assume that one top-down layer has exactly one upscaling layer.
+             new_latent_shape = (
+                 self.latent_shape[0] // 2 + extra_len,
+                 self.latent_shape[1] // 2 + extra_len,
+             )
+
+             # This can happen if LC is not applied to all layers.
+             if x.shape[-1] > new_latent_shape[-1]:
+                 x = F.center_crop(x, new_latent_shape)
+
+         # Last top-down block (sequence of residual blocks)
+         x = self.deterministic_block(x)
+
+         if self.topdown_no_padding_mode:
+             x = F.center_crop(x, self.latent_shape)
+
+         # Save some metrics that will be used in the loss computation
+         keys = [
+             "z",
+             "kl_samplewise",
+             "kl_samplewise_restricted",
+             "kl_spatial",
+             "kl_channelwise",
+             # 'logprob_p',
+             "logprob_q",
+             "qvar_max",
+         ]
+         data = {k: data_stoch.get(k, None) for k in keys}
+         data["q_mu"] = None
+         data["q_lv"] = None
+         if data_stoch["q_params"] is not None:
+             q_mu, q_lv = data_stoch["q_params"]
+             data["q_mu"] = q_mu
+             data["q_lv"] = q_lv
+
+         return x, x_pre_residual, data
+
+
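The no-padding branch above relies on torchvision's `center_crop` to reconcile shapes before merging. A minimal standalone sketch of the same shape-matching step, with purely illustrative tensor sizes:

import torch
from torchvision.transforms import functional as F

# Hypothetical shapes: the skip tensor is larger than the current activation
# because the current layer ran its convolutions without padding.
x = torch.randn(1, 32, 60, 60)     # current top-down activation
skip = torch.randn(1, 32, 64, 64)  # skip connection from the previous layer

# Center-crop the skip input to the spatial size of x before merging,
# mirroring the `topdown_no_padding_mode` branch above.
skip = F.center_crop(skip, x.shape[-2:])
assert skip.shape[-2:] == x.shape[-2:]
merged = torch.cat([x, skip], dim=1)  # stand-in for skip_connection_merger
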
+ class NormalStochasticBlock2d(nn.Module):
+     """
+     Stochastic block used in the Top-Down inference pass.
+
+     Algorithm:
+         - map input parameters to q(z) and (optionally) p(z) via convolution
+         - sample a latent tensor z ~ q(z)
+         - feed z to a convolution and return.
+
+     NOTE 1:
+         If parameters for q are not given, sampling is done from p(z).
+
+     NOTE 2:
+         The restricted KL divergence is obtained by first computing the element-wise KL
+         divergence (i.e., the KL computed for each element of the latent tensors). Then,
+         the restricted version is computed by summing over the channels and the spatial
+         dimensions associated only to the portion of the latent tensor that is used for
+         prediction.
+     """
+
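NOTE 2 can be made concrete with a few lines of tensor algebra. A minimal sketch of the restricted KL reduction, assuming a hypothetical element-wise KL tensor and an illustrative `vanilla_latent_hw` value:

import torch

# Hypothetical element-wise KL term, shape (batch, ch, h, w)
kl_elementwise = torch.rand(4, 32, 16, 16)
vanilla_latent_hw = 8  # assumed size of the latent region used for prediction

# Keep only the central region actually used for prediction...
pad = (kl_elementwise.shape[-1] - vanilla_latent_hw) // 2
restricted = kl_elementwise[..., pad:-pad, pad:-pad]

# ...then sum over channel and spatial dimensions -> shape (batch,)
kl_samplewise_restricted = restricted.sum((1, 2, 3))
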
+     def __init__(
+         self,
+         c_in: int,
+         c_vars: int,
+         c_out: int,
+         kernel: int = 3,
+         transform_p_params: bool = True,
+         vanilla_latent_hw: int = None,
+         restricted_kl: bool = False,
+         use_naive_exponential: bool = False,
+     ):
+         """
+         Parameters
+         ----------
+         c_in: int
+             The number of channels of the input tensor.
+         c_vars: int
+             The number of channels of the latent space tensor.
+         c_out: int
+             The number of output channels of the stochastic layer.
+             Note that this is different from the sampled latent z.
+         kernel: int, optional
+             The size of the kernel used in the convolutional layers.
+             Default is 3.
+         transform_p_params: bool, optional
+             Whether a transformation should be applied to the `p_params` tensor.
+             The transformation consists in a 2D convolution (`conv_in_p()`) that
+             maps the input to a larger number of channels.
+             Default is `True`.
+         vanilla_latent_hw: int, optional
+             The shape of the latent tensor used for prediction (i.e., it influences
+             the computation of the restricted KL). Default is `None`.
+         restricted_kl: bool, optional
+             Whether to compute the restricted version of the KL divergence.
+             See NOTE 2 for more information about its computation.
+             Default is `False`.
+         use_naive_exponential: bool, optional
+             If `False`, exponentials are computed according to the alternative definition
+             provided by the `StableExponential` class. This should improve numerical
+             stability in the training process. Default is `False`.
+         """
+         super().__init__()
+         assert kernel % 2 == 1
+         pad = kernel // 2
+         self.transform_p_params = transform_p_params
+         self.c_in = c_in
+         self.c_out = c_out
+         self.c_vars = c_vars
+         self._use_naive_exponential = use_naive_exponential
+         self._vanilla_latent_hw = vanilla_latent_hw
+         self._restricted_kl = restricted_kl
+
+         if transform_p_params:
+             self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
+         self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
+         self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)
+
+     # def forward_swapped(self, p_params, q_mu, q_lv):
+     #
+     #     if self.transform_p_params:
+     #         p_params = self.conv_in_p(p_params)
+     #     else:
+     #         assert p_params.size(1) == 2 * self.c_vars
+     #
+     #     # Define p(z)
+     #     p_mu, p_lv = p_params.chunk(2, dim=1)
+     #     p = Normal(p_mu, (p_lv / 2).exp())
+     #
+     #     # Define q(z)
+     #     q = Normal(q_mu, (q_lv / 2).exp())
+     #     # Sample from q(z)
+     #     sampling_distrib = q
+     #
+     #     # Generate latent variable (typically by sampling)
+     #     z = sampling_distrib.rsample()
+     #
+     #     # Output of stochastic layer
+     #     out = self.conv_out(z)
+     #
+     #     data = {
+     #         'z': z,  # sampled variable at this layer (batch, ch, h, w)
+     #         'p_params': p_params,  # (b, ch, h, w) where b is 1 or batch size
+     #     }
+     #     return out, data
+
+     def get_z(
+         self,
+         sampling_distrib: torch.distributions.normal.Normal,
+         forced_latent: torch.Tensor,
+         use_mode: bool,
+         mode_pred: bool,
+         use_uncond_mode: bool,
+     ) -> torch.Tensor:
+         """
+         Sample a latent tensor given the distribution to sample from.
+
+         The latent variable can be obtained in several ways:
+             - Sampled from the (Gaussian) latent distribution.
+             - Taken as a pre-defined forced latent.
+             - Taken as the mode (mean) of the latent distribution.
+             - In prediction mode (`mode_pred==True`), either sampled or taken as the
+               distribution mode.
+
+         Parameters
+         ----------
+         sampling_distrib: torch.distributions.normal.Normal
+             The Gaussian distribution from which the latent tensor is sampled.
+         forced_latent: torch.Tensor
+             A pre-defined latent tensor. If it is not `None`, then it is used as the
+             actual latent tensor and, hence, sampling does not happen.
+         use_mode: bool
+             Whether the latent tensor should be set as the latent distribution mode.
+             In the case of a Gaussian, the mode coincides with the mean of the distribution.
+         mode_pred: bool
+             Whether the model is in prediction mode.
+         use_uncond_mode: bool
+             Whether to use the unconditional distribution p(z) to sample latents in
+             prediction mode.
+         """
+         if forced_latent is None:
+             if use_mode:
+                 z = sampling_distrib.mean
+             else:
+                 if mode_pred:
+                     if use_uncond_mode:
+                         z = sampling_distrib.mean
+                     else:
+                         z = sampling_distrib.rsample()
+                 else:
+                     z = sampling_distrib.rsample()
+         else:
+             z = forced_latent
+         return z
+
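The selection logic in `get_z` can be exercised on its own. A small sketch with a toy diagonal Gaussian (all tensor values are illustrative):

import torch
from torch.distributions import Normal

distrib = Normal(torch.zeros(2, 4), torch.ones(2, 4))

# Mode of a Gaussian == its mean, so `use_mode=True` is deterministic.
z_mode = distrib.mean

# `rsample()` uses the reparameterization trick (mean + std * eps),
# so gradients can flow through the sampling step during training.
z_sample = distrib.rsample()

# A forced latent bypasses sampling entirely.
z_forced = torch.full((2, 4), 0.5)
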
+     def sample_from_q(
+         self, q_params: torch.Tensor, var_clip_max: float
+     ) -> torch.Tensor:
+         """
+         Given an input parameter tensor defining q(z), process it with the
+         `process_q_params()` method and sample a latent tensor from the
+         resulting distribution.
+
+         Parameters
+         ----------
+         q_params: torch.Tensor
+             The input tensor to be processed.
+         var_clip_max: float
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped.
+         """
+         _, _, q = self.process_q_params(q_params, var_clip_max)
+         return q.rsample()
+
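Under the hood, the parameter tensor carries a mean and a log-variance; the log-variance is optionally clipped before being turned into a standard deviation. A minimal sketch of that conversion (the clipping threshold and shapes are illustrative assumptions):

import torch
from torch.distributions import Normal

q_params = torch.randn(2, 8, 16, 16)   # hypothetical conv output: 2 * c_vars channels
q_mu, q_lv = q_params.chunk(2, dim=1)  # split into mean and log-variance

q_lv = torch.clip(q_lv, max=6.0)       # clip log-variance (var_clip_max assumed 6.0)
q = Normal(q_mu, (q_lv / 2).exp())     # std = exp(log_var / 2)
z = q.rsample()                        # shape (2, 4, 16, 16)
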
+     def compute_kl_metrics(
+         self,
+         p: torch.distributions.normal.Normal,
+         p_params: torch.Tensor,
+         q: torch.distributions.normal.Normal,
+         q_params: torch.Tensor,
+         mode_pred: bool,
+         analytical_kl: bool,
+         z: torch.Tensor,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compute the KL divergence (analytical or MC estimate) and then process it,
+         extracting composed versions of the metric.
+
+         Specifically, the different versions of the KL loss terms are:
+             - `kl_elementwise`: KL term for each single element of the latent tensor
+               [Shape: (batch, ch, h, w)].
+             - `kl_samplewise`: KL term associated to each sample in the batch
+               [Shape: (batch, )].
+             - `kl_samplewise_restricted`: KL term only associated to the portion of the
+               latent tensor that is used for prediction, summed over channel and spatial
+               dimensions [Shape: (batch, )].
+             - `kl_channelwise`: KL term associated to each sample and each channel
+               [Shape: (batch, ch)].
+             - `kl_spatial`: KL term summed over the channels, i.e., retaining the
+               spatial dimensions [Shape: (batch, h, w)].
+
+         Parameters
+         ----------
+         p: torch.distributions.normal.Normal
+             The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
+         p_params: torch.Tensor
+             The parameters of the prior generative distribution.
+         q: torch.distributions.normal.Normal
+             The inference distribution q(z_i|z_{i+1}, x) (or q(z_L|x)).
+         q_params: torch.Tensor
+             The parameters of the inference distribution.
+         mode_pred: bool
+             Whether the model is in prediction mode.
+         analytical_kl: bool
+             Whether to compute the KL divergence analytically or using Monte Carlo estimation.
+         z: torch.Tensor
+             The sampled latent tensor.
+         """
+         kl_samplewise_restricted = None
+
+         if mode_pred is False:  # if not in prediction mode
+             # KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]
+             if analytical_kl:
+                 kl_elementwise = kl_divergence(q, p)
+             else:
+                 kl_elementwise = kl_normal_mc(z, p_params, q_params)
+
+             # KL term only associated to the portion of the latent tensor that is used
+             # for prediction, summed over channel and spatial dimensions. [Shape: (batch, )]
+             # NOTE: vanilla_latent_hw is the shape of the latent tensor used for prediction,
+             # hence the restriction has shape (batch, ch, vanilla_latent_hw[0], vanilla_latent_hw[1])
+             if self._restricted_kl:
+                 pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2
+                 assert pad > 0, "Disable restricted kl since there is no restriction."
+                 tmp = kl_elementwise[..., pad:-pad, pad:-pad]
+                 kl_samplewise_restricted = tmp.sum((1, 2, 3))
+
+             # KL term associated to each sample in the batch [Shape: (batch, )]
+             kl_samplewise = kl_elementwise.sum((1, 2, 3))
+
+             # KL term associated to each sample and each channel [Shape: (batch, ch, )]
+             kl_channelwise = kl_elementwise.sum((2, 3))
+
+             # KL term summed over the channels, i.e., retaining the spatial dimensions
+             # [Shape: (batch, h, w)]
+             kl_spatial = kl_elementwise.sum(1)
+         else:  # if predicting, no need to compute the KL
+             kl_elementwise = kl_samplewise = kl_spatial = kl_channelwise = None
+
+         kl_dict = {
+             "kl_elementwise": kl_elementwise,  # (batch, ch, h, w)
+             "kl_samplewise": kl_samplewise,  # (batch, )
+             "kl_samplewise_restricted": kl_samplewise_restricted,  # (batch, )
+             "kl_channelwise": kl_channelwise,  # (batch, ch)
+             "kl_spatial": kl_spatial,  # (batch, h, w)
+         }
+         return kl_dict
+
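The various reductions differ only in which axes they sum over. A toy sketch of the analytical path, using two diagonal Gaussians with illustrative shapes:

import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.zeros(4, 32, 16, 16), torch.ones(4, 32, 16, 16))
q = Normal(torch.randn(4, 32, 16, 16), torch.ones(4, 32, 16, 16))

kl_elementwise = kl_divergence(q, p)             # (batch, ch, h, w)
kl_samplewise = kl_elementwise.sum((1, 2, 3))    # (batch,)
kl_channelwise = kl_elementwise.sum((2, 3))      # (batch, ch)
kl_spatial = kl_elementwise.sum(1)               # (batch, h, w)
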
+     def process_p_params(
+         self, p_params: torch.Tensor, var_clip_max: float
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
+         """
+         Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)).
+
+         Processing consists in:
+             - (optionally) a 2D convolution on the input tensor to increase the number of channels.
+             - splitting the resulting tensor into two chunks, the mean and the log-variance.
+             - (optionally) clipping the log-variance to an upper threshold.
+             - defining the normal distribution p(z) given the parameter tensors above.
+
+         Parameters
+         ----------
+         p_params: torch.Tensor
+             The input tensor to be processed.
+         var_clip_max: float
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped.
+         """
+         if self.transform_p_params:
+             p_params = self.conv_in_p(p_params)
+         else:
+             assert p_params.size(1) == 2 * self.c_vars
+
+         # Define p(z)
+         p_mu, p_lv = p_params.chunk(2, dim=1)
+         if var_clip_max is not None:
+             p_lv = torch.clip(p_lv, max=var_clip_max)
+
+         p_mu = StableMean(p_mu)
+         p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential)
+         p = Normal(p_mu.get(), p_lv.get_std())
+         return p_mu, p_lv, p
+
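The convolution-and-chunk pattern above doubles the channel count so the result can be split into distribution parameters. A standalone sketch of that step (channel counts are illustrative assumptions):

import torch
from torch import nn

c_in, c_vars = 64, 32
conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel_size=3, padding=1)

x = torch.randn(1, c_in, 16, 16)
p_params = conv_in_p(x)                # (1, 2 * c_vars, 16, 16)
assert p_params.size(1) == 2 * c_vars
p_mu, p_lv = p_params.chunk(2, dim=1)  # each (1, c_vars, 16, 16)
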
+     def process_q_params(
+         self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
+         """
+         Process the input parameters to get the inference distribution q(z_i|z_{i+1}, x)
+         (or q(z|x)).
+
+         Processing consists in:
+             - a 2D convolution on the input tensor to increase the number of channels.
+             - splitting the resulting tensor into two chunks, the mean and the log-variance.
+             - (optionally) clipping the log-variance to an upper threshold.
+             - (optionally) cropping the resulting tensors to ensure that the last spatial
+               dimension is even.
+             - defining the normal distribution q(z) given the parameter tensors above.
+
+         Parameters
+         ----------
+         q_params: torch.Tensor
+             The input tensor to be processed.
+         var_clip_max: float
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped.
+         """
+         q_params = self.conv_in_q(q_params)
+
+         q_mu, q_lv = q_params.chunk(2, dim=1)
+         if var_clip_max is not None:
+             q_lv = torch.clip(q_lv, max=var_clip_max)
+
+         if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
+             q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
+             q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1)
+             # clip_start = np.random.rand() > 0.5
+             # q_mu = q_mu[:, :, 1:, 1:] if clip_start else q_mu[:, :, :-1, :-1]
+             # q_lv = q_lv[:, :, 1:, 1:] if clip_start else q_lv[:, :, :-1, :-1]
+
+         q_mu = StableMean(q_mu)
+         q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential)
+         q = Normal(q_mu.get(), q_lv.get_std())
+         return q_mu, q_lv, q
+
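The odd-size guard exists because later merge and crop steps assume even spatial dimensions. A short sketch of the crop, with illustrative sizes:

import torch
from torchvision.transforms import functional as F

q_mu = torch.randn(1, 32, 17, 17)  # odd spatial size
if q_mu.shape[-1] % 2 == 1:
    # Drop one row and column symmetrically to reach an even size.
    q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
assert q_mu.shape[-1] == 16
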
+     def forward(
+         self,
+         p_params: torch.Tensor,
+         q_params: torch.Tensor = None,
+         forced_latent: torch.Tensor = None,
+         use_mode: bool = False,
+         force_constant_output: bool = False,
+         analytical_kl: bool = False,
+         mode_pred: bool = False,
+         use_uncond_mode: bool = False,
+         var_clip_max: float = None,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Forward pass of the stochastic block.
+
+         Parameters
+         ----------
+         p_params: torch.Tensor
+             The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
+         q_params: torch.Tensor, optional
+             The tensor resulting from merging the bu_value tensor at the same hierarchical
+             level from the bottom-up pass and the `p_params` tensor. Default is `None`.
+         forced_latent: torch.Tensor, optional
+             A pre-defined latent tensor. If it is not `None`, then it is used as the actual
+             latent tensor and, hence, sampling does not happen. Default is `None`.
+         use_mode: bool, optional
+             Whether the latent tensor should be set as the latent distribution mode.
+             In the case of a Gaussian, the mode coincides with the mean of the distribution.
+             Default is `False`.
+         force_constant_output: bool, optional
+             Whether to copy the first sample (and related distribution parameters) over
+             the whole batch. This is used when running experiments from the prior, in
+             which q is not used. Default is `False`.
+         analytical_kl: bool, optional
+             Whether to compute the KL divergence analytically or using Monte Carlo estimation.
+             Default is `False`.
+         mode_pred: bool, optional
+             Whether the model is in prediction mode. Default is `False`.
+         use_uncond_mode: bool, optional
+             Whether to use the unconditional distribution p(z) to sample latents in
+             prediction mode. Default is `False`.
+         var_clip_max: float, optional
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped. Default is `None`.
+         """
+         debug_qvar_max = 0
+
+         # Check sampling options consistency
+         assert (forced_latent is None) or (not use_mode)
+
+         # Get the generative distribution p(z_i|z_{i+1})
+         p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max)
+         p_params = (p_mu, p_lv)
+
+         if q_params is not None:
+             # Get the inference distribution q(z_i|z_{i+1}, x)
+             # NOTE: At inference time, don't center-crop the q_params even if they are odd in size.
+             q_mu, q_lv, q = self.process_q_params(
+                 q_params, var_clip_max, allow_oddsizes=mode_pred is True
+             )
+             q_params = (q_mu, q_lv)
+             sampling_distrib = q
+             debug_qvar_max = torch.max(q_lv.get())
+
+             # Center-crop p_params so that their size matches the size of q_params
+             q_size = q_mu.get().shape[-1]
+             if p_mu.get().shape[-1] != q_size and mode_pred is False:
+                 p_mu.centercrop_to_size(q_size)
+                 p_lv.centercrop_to_size(q_size)
+         else:
+             sampling_distrib = p
+
+         # Sample the latent variable
+         z = self.get_z(
+             sampling_distrib, forced_latent, use_mode, mode_pred, use_uncond_mode
+         )
+
+         # Copy one sample (and distribution parameters) over the whole batch.
+         # This is used when running experiments from the prior - q is not used.
+         if force_constant_output:
+             z = z[0:1].expand_as(z).clone()
+             p_params = (
+                 p_params[0][0:1].expand_as(p_params[0]).clone(),
+                 p_params[1][0:1].expand_as(p_params[1]).clone(),
+             )
+
+         # Pass the sampled latent through the output convolutional layer of the stochastic block
+         out = self.conv_out(z)
+
+         # Compute log p(z). NOTE: its computation is disabled.
+         # if mode_pred is False:
+         #     logprob_p = p.log_prob(z).sum((1, 2, 3))
+         # else:
+         #     logprob_p = None
+
+         if q_params is not None:
+             # Compute log q(z)
+             logprob_q = q.log_prob(z).sum((1, 2, 3))
+             # Compute the KL divergence metrics
+             kl_dict = self.compute_kl_metrics(
+                 p, p_params, q, q_params, mode_pred, analytical_kl, z
+             )
+         else:
+             kl_dict = {}
+             logprob_q = None
+
+         # Store meaningful quantities to use them in the following layers
+         data = kl_dict
+         data["z"] = z  # sampled variable at this layer (batch, ch, h, w)
+         data["p_params"] = p_params  # (b, ch, h, w) where b is 1 or batch size
+         data["q_params"] = q_params  # (batch, ch, h, w)
+         # data['logprob_p'] = logprob_p  # (batch, )
+         data["logprob_q"] = logprob_q  # (batch, )
+         data["qvar_max"] = debug_qvar_max
+
+         return out, data
+
+
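Assuming the class as defined above (and the `StableMean`/`StableLogVar` helpers and `kl_divergence` import from elsewhere in this module), a hypothetical usage sketch might look as follows; channel counts and spatial sizes are illustrative:

import torch

block = NormalStochasticBlock2d(c_in=64, c_vars=32, c_out=64)

p_params = torch.randn(2, 64, 16, 16)  # from the top-down layer above
q_params = torch.randn(2, 64, 16, 16)  # merged bottom-up / top-down features

# Training-style call: sample from q and collect the KL metrics.
out, data = block(p_params=p_params, q_params=q_params, analytical_kl=True)
print(out.shape)                    # torch.Size([2, 64, 16, 16])
print(data["kl_samplewise"].shape)  # torch.Size([2])
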
+ class NonStochasticBlock2d(nn.Module):
+     """
+     Non-stochastic version of the `NormalStochasticBlock2d`.
+     """
+
+     def __init__(
+         self,
+         c_vars: int,
+         c_in: int,
+         c_out: int,
+         kernel: int = 3,
+         groups: int = 1,
+         conv2d_bias: bool = True,
+         transform_p_params: bool = True,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         c_vars: int
+             The number of channels of the latent space tensor.
+         c_in: int
+             The number of channels of the input tensor.
+         c_out: int
+             The number of output channels of the stochastic layer.
+             Note that this is different from the sampled latent z.
+         kernel: int, optional
+             The size of the kernel used in the convolutional layers.
+             Default is 3.
+         groups: int, optional
+             The number of groups to consider in the convolutions of this layer.
+             Default is 1.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in the convolutional blocks of this layer.
+             Default is `True`.
+         transform_p_params: bool, optional
+             Whether a transformation should be applied to the `p_params` tensor.
+             The transformation consists in a 2D convolution (`conv_in_p()`) that
+             maps the input to a larger number of channels.
+             Default is `True`.
+         """
+         super().__init__()
+         assert kernel % 2 == 1
+         pad = kernel // 2
+         self.transform_p_params = transform_p_params
+         self.c_in = c_in
+         self.c_out = c_out
+         self.c_vars = c_vars
+
+         if transform_p_params:
+             self.conv_in_p = nn.Conv2d(
+                 c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
+             )
+         self.conv_in_q = nn.Conv2d(
+             c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
+         )
+         self.conv_out = nn.Conv2d(
+             c_vars, c_out, kernel, padding=pad, bias=conv2d_bias, groups=groups
+         )
+
+     def compute_kl_metrics(
+         self,
+         p: torch.distributions.normal.Normal,
+         p_params: torch.Tensor,
+         q: torch.distributions.normal.Normal,
+         q_params: torch.Tensor,
+         mode_pred: bool,
+         analytical_kl: bool,
+         z: torch.Tensor,
+     ) -> Dict[str, None]:
+         """
+         Compute the KL divergence (analytical or MC estimate) and then process it,
+         extracting composed versions of the metric.
+
+         Specifically, the different versions of the KL loss terms are:
+             - `kl_elementwise`: KL term for each single element of the latent tensor
+               [Shape: (batch, ch, h, w)].
+             - `kl_samplewise`: KL term associated to each sample in the batch
+               [Shape: (batch, )].
+             - `kl_samplewise_restricted`: KL term only associated to the portion of the
+               latent tensor that is used for prediction, summed over channel and spatial
+               dimensions [Shape: (batch, )].
+             - `kl_channelwise`: KL term associated to each sample and each channel
+               [Shape: (batch, ch)].
+             - `kl_spatial`: KL term summed over the channels, i.e., retaining the
+               spatial dimensions [Shape: (batch, h, w)].
+
+         NOTE: in this class all the KL metrics are set to `None`.
+
+         Parameters
+         ----------
+         p: torch.distributions.normal.Normal
+             The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
+         p_params: torch.Tensor
+             The parameters of the prior generative distribution.
+         q: torch.distributions.normal.Normal
+             The inference distribution q(z_i|z_{i+1}, x) (or q(z_L|x)).
+         q_params: torch.Tensor
+             The parameters of the inference distribution.
+         mode_pred: bool
+             Whether the model is in prediction mode.
+         analytical_kl: bool
+             Whether to compute the KL divergence analytically or using Monte Carlo estimation.
+         z: torch.Tensor
+             The sampled latent tensor.
+         """
+         kl_dict = {
+             "kl_elementwise": None,  # (batch, ch, h, w)
+             "kl_samplewise": None,  # (batch, )
+             "kl_spatial": None,  # (batch, h, w)
+             "kl_channelwise": None,  # (batch, ch)
+         }
+         return kl_dict
+
+     def process_p_params(self, p_params, var_clip_max):
+         if self.transform_p_params:
+             p_params = self.conv_in_p(p_params)
+         else:
+             assert (
+                 p_params.size(1) == 2 * self.c_vars
+             ), f"{p_params.shape} {self.c_vars}"
+
+         # Define p(z): only the mean is used in the non-stochastic block
+         p_mu, p_lv = p_params.chunk(2, dim=1)
+         return p_mu, None
+
+     def process_q_params(self, q_params, var_clip_max, allow_oddsizes=False):
+         # Define q(z): only the mean is used in the non-stochastic block
+         q_params = self.conv_in_q(q_params)
+         q_mu, q_lv = q_params.chunk(2, dim=1)
+
+         if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
+             q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
+
+         return q_mu, None
+
+     def forward(
+         self,
+         p_params: torch.Tensor,
+         q_params: torch.Tensor = None,
+         forced_latent: Union[None, torch.Tensor] = None,
+         use_mode: bool = False,
+         force_constant_output: bool = False,
+         analytical_kl: bool = False,
+         mode_pred: bool = False,
+         use_uncond_mode: bool = False,
+         var_clip_max: float = None,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Forward pass of the non-stochastic block.
+
+         Parameters
+         ----------
+         p_params: torch.Tensor
+             The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
+         q_params: torch.Tensor, optional
+             The tensor resulting from merging the bu_value tensor at the same hierarchical
+             level from the bottom-up pass and the `p_params` tensor. Default is `None`.
+         forced_latent: torch.Tensor, optional
+             A pre-defined latent tensor. If it is not `None`, then it is used as the actual
+             latent tensor and, hence, sampling does not happen. Default is `None`.
+         use_mode: bool, optional
+             Whether the latent tensor should be set as the latent distribution mode.
+             In the case of a Gaussian, the mode coincides with the mean of the distribution.
+             Default is `False`.
+         force_constant_output: bool, optional
+             Whether to copy the first sample (and related distribution parameters) over
+             the whole batch. This is used when running experiments from the prior, in
+             which q is not used. Default is `False`.
+         analytical_kl: bool, optional
+             Whether to compute the KL divergence analytically or using Monte Carlo estimation.
+             Default is `False`.
+         mode_pred: bool, optional
+             Whether the model is in prediction mode. Default is `False`.
+         use_uncond_mode: bool, optional
+             Whether to use the unconditional distribution p(z) to sample latents in
+             prediction mode. Default is `False`.
+         var_clip_max: float, optional
+             The maximum value reachable by the log-variance of the latent distribution.
+             Values exceeding this threshold are clipped. Default is `None`.
+         """
+         debug_qvar_max = 0
+         assert (forced_latent is None) or (not use_mode)
+
+         p_mu, _ = self.process_p_params(p_params, var_clip_max)
+
+         p_params = (p_mu, None)
+
+         if q_params is not None:
+             # At inference time, just don't center-crop the q_params even if they are odd in size.
+             q_mu, _ = self.process_q_params(
+                 q_params, var_clip_max, allow_oddsizes=mode_pred is True
+             )
+             q_params = (q_mu, None)
+             debug_qvar_max = torch.Tensor([1]).to(q_mu.device)
+             # "Sample" from q(z): the deterministic block simply takes the mean
+             sampling_distrib = q_mu
+             q_size = q_mu.shape[-1]
+             if p_mu.shape[-1] != q_size and mode_pred is False:
+                 # NOTE: p_mu is a plain tensor here (not a StableMean object), so we
+                 # center-crop it directly and refresh p_params accordingly.
+                 p_mu = F.center_crop(p_mu, q_size)
+                 p_params = (p_mu, None)
+         else:
+             # "Sample" from p(z)
+             sampling_distrib = p_mu
+
+         # Generate the latent variable (here taken deterministically, without sampling)
+         z = sampling_distrib
+
+         # Copy one sample (and distribution parameters) over the whole batch.
+         # This is used when running experiments from the prior - q is not used.
+         if force_constant_output:
+             z = z[0:1].expand_as(z).clone()
+             # NOTE: the second element of p_params is always None in this block,
+             # so only the mean is expanded.
+             p_params = (
+                 p_params[0][0:1].expand_as(p_params[0]).clone(),
+                 None,
+             )
+
+         # Output of the (non-)stochastic layer
+         out = self.conv_out(z)
+
+         kl_dict = {}
+         logprob_q = None
+
+         data = kl_dict
+         data["z"] = z  # sampled variable at this layer (batch, ch, h, w)
+         data["p_params"] = p_params  # (b, ch, h, w) where b is 1 or batch size
+         data["q_params"] = q_params  # (batch, ch, h, w)
+         data["logprob_q"] = logprob_q  # (batch, )
+         data["qvar_max"] = debug_qvar_max
+
+         return out, data
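
A matching hypothetical usage sketch for the deterministic variant; it returns empty KL metrics, and the "latent" is simply the q-mean (sizes are illustrative):

import torch

block = NonStochasticBlock2d(c_vars=32, c_in=64, c_out=64)

p_params = torch.randn(2, 64, 16, 16)
q_params = torch.randn(2, 64, 16, 16)

out, data = block(p_params=p_params, q_params=q_params)
print(out.shape)        # torch.Size([2, 64, 16, 16])
print(data["z"].shape)  # torch.Size([2, 32, 16, 16]) - the q-mean, used as "latent"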