careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

This version of careamics might be problematic.
Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
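The largest single change is the rewrite of careamics/models/lvae/layers.py (diffed below), which replaces the hard-coded 2D modules with blocks that work in both 2D and 3D. A minimal sketch of the dispatch pattern visible in the diff, assuming nothing beyond what the new code shows (the variable names mirror the diff; the rest is illustrative):

import torch.nn as nn

# The length of the new `conv_strides` argument selects the module family:
# two entries -> planar (2D) modules, three entries -> volumetric (3D) modules.
conv_strides = (1, 2, 2)
conv_layer = getattr(nn, f"Conv{len(conv_strides)}d")         # -> nn.Conv3d
norm_layer = getattr(nn, f"BatchNorm{len(conv_strides)}d")    # -> nn.BatchNorm3d
dropout_layer = getattr(nn, f"Dropout{len(conv_strides)}d")   # -> nn.Dropout3d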
@@ -1,27 +1,23 @@
1
- """
2
- Script containing the common basic blocks (nn.Module) reused by the LadderVAE architecture.
3
-
4
- Hierarchy in the model blocks:
5
-
6
- """
1
+ """Script containing the common basic blocks (nn.Module) reused by the LadderVAE."""
7
2
 
3
+ from collections.abc import Iterable
8
4
  from copy import deepcopy
9
- from typing import Callable, Dict, Iterable, Literal, Tuple, Union
5
+ from typing import Callable, Literal, Optional, Union
10
6
 
7
+ import numpy as np
11
8
  import torch
12
9
  import torch.nn as nn
13
- import torchvision.transforms.functional as F
14
- from torch.distributions import kl_divergence
15
- from torch.distributions.normal import Normal
16
10
 
11
+ from .stochastic import NormalStochasticBlock
17
12
  from .utils import (
18
- StableLogVar,
19
- StableMean,
20
13
  crop_img_tensor,
21
- kl_normal_mc,
22
14
  pad_img_tensor,
23
15
  )
24
16
 
17
+ ConvType = Union[nn.Conv2d, nn.Conv3d]
18
+ NormType = Union[nn.BatchNorm2d, nn.BatchNorm3d]
19
+ DropoutType = Union[nn.Dropout2d, nn.Dropout3d]
20
+
25
21
 
26
22
  class ResidualBlock(nn.Module):
27
23
  """
@@ -51,13 +47,13 @@ class ResidualBlock(nn.Module):
51
47
  self,
52
48
  channels: int,
53
49
  nonlin: Callable,
54
- kernel: Union[int, Iterable[int]] = None,
50
+ conv_strides: tuple[int] = (2, 2),
51
+ kernel: Union[int, Iterable[int], None] = None,
55
52
  groups: int = 1,
56
53
  batchnorm: bool = True,
57
54
  block_type: str = None,
58
55
  dropout: float = None,
59
56
  gated: bool = None,
60
- skip_padding: bool = False,
61
57
  conv2d_bias: bool = True,
62
58
  ):
63
59
  """
@@ -85,8 +81,6 @@ class ResidualBlock(nn.Module):
85
81
  Default is `None`.
86
82
  gated: bool, optional
87
83
  Whether to use gated layer. Default is `None`.
88
- skip_padding: bool, optional
89
- Whether to skip padding in convolutions. Default is `False`.
90
84
  conv2d_bias: bool, optional
91
85
  Whether to use bias term in convolutions. Default is `True`.
92
86
  """
@@ -99,99 +93,142 @@ class ResidualBlock(nn.Module):
99
93
  kernel = (kernel, kernel)
100
94
  elif len(kernel) != 2:
101
95
  raise ValueError("kernel has to be None, int, or an iterable of length 2")
102
- assert all([k % 2 == 1 for k in kernel]), "kernel sizes have to be odd"
96
+ assert all(k % 2 == 1 for k in kernel), "kernel sizes have to be odd"
103
97
  kernel = list(kernel)
104
- self.skip_padding = skip_padding
105
- pad = [0] * len(kernel) if self.skip_padding else [k // 2 for k in kernel]
106
- # print(kernel, pad)
98
+
99
+ # Define modules
100
+ conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
101
+ norm_layer: NormType = getattr(nn, f"BatchNorm{len(conv_strides)}d")
102
+ dropout_layer: DropoutType = getattr(nn, f"Dropout{len(conv_strides)}d")
103
+ # TODO: same comment as in lvae.py, would be more readable to have `conv_dims`
107
104
 
108
105
  modules = []
109
106
  if block_type == "cabdcabd":
110
107
  for i in range(2):
111
- conv = nn.Conv2d(
108
+ conv = conv_layer(
112
109
  channels,
113
110
  channels,
114
111
  kernel[i],
115
- padding=pad[i],
112
+ padding="same",
116
113
  groups=groups,
117
114
  bias=conv2d_bias,
118
115
  )
119
116
  modules.append(conv)
120
117
  modules.append(nonlin)
121
118
  if batchnorm:
122
- modules.append(nn.BatchNorm2d(channels))
119
+ modules.append(norm_layer(channels))
123
120
  if dropout is not None:
124
- modules.append(nn.Dropout2d(dropout))
121
+ modules.append(dropout_layer(dropout))
125
122
  elif block_type == "bacdbac":
126
123
  for i in range(2):
127
124
  if batchnorm:
128
- modules.append(nn.BatchNorm2d(channels))
125
+ modules.append(norm_layer(channels))
129
126
  modules.append(nonlin)
130
- conv = nn.Conv2d(
127
+ conv = conv_layer(
131
128
  channels,
132
129
  channels,
133
130
  kernel[i],
134
- padding=pad[i],
131
+ padding="same",
135
132
  groups=groups,
136
133
  bias=conv2d_bias,
137
134
  )
138
135
  modules.append(conv)
139
136
  if dropout is not None and i == 0:
140
- modules.append(nn.Dropout2d(dropout))
137
+ modules.append(dropout_layer(dropout))
141
138
  elif block_type == "bacdbacd":
142
139
  for i in range(2):
143
140
  if batchnorm:
144
- modules.append(nn.BatchNorm2d(channels))
141
+ modules.append(norm_layer(channels))
145
142
  modules.append(nonlin)
146
- conv = nn.Conv2d(
143
+ conv = conv_layer(
147
144
  channels,
148
145
  channels,
149
146
  kernel[i],
150
- padding=pad[i],
147
+ padding="same",
151
148
  groups=groups,
152
149
  bias=conv2d_bias,
153
150
  )
154
151
  modules.append(conv)
155
- modules.append(nn.Dropout2d(dropout))
152
+ modules.append(dropout_layer(dropout))
156
153
 
157
154
  else:
158
155
  raise ValueError(f"unrecognized block type '{block_type}'")
159
156
 
160
157
  self.gated = gated
161
158
  if gated:
162
- modules.append(GateLayer2d(channels, 1, nonlin))
159
+ modules.append(
160
+ GateLayer(
161
+ channels=channels,
162
+ conv_strides=conv_strides,
163
+ kernel_size=1,
164
+ nonlin=nonlin,
165
+ )
166
+ )
163
167
 
164
168
  self.block = nn.Sequential(*modules)
165
169
 
166
- def forward(self, x):
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ """Forward pass.
172
+
173
+ Parameters
174
+ ----------
175
+ x : torch.Tensor
176
+ input tensor # TODO add shape
167
177
 
178
+ Returns
179
+ -------
180
+ torch.Tensor
181
+ output tensor # TODO add shape
182
+ """
168
183
  out = self.block(x)
169
- if out.shape != x.shape:
170
- return out + F.center_crop(x, out.shape[-2:])
171
- else:
172
- return out + x
184
+ assert (
185
+ out.shape == x.shape
186
+ ), f"output shape: {out.shape} != input shape: {x.shape}"
187
+ return out + x
173
188
 
174
189
 
175
190
  class ResidualGatedBlock(ResidualBlock):
191
+ """Layer class that implements a residual block with a gating mechanism."""
176
192
 
177
193
  def __init__(self, *args, **kwargs):
178
194
  super().__init__(*args, **kwargs, gated=True)
179
195
 
180
196
 
181
- class GateLayer2d(nn.Module):
197
+ class GateLayer(nn.Module):
182
198
  """
199
+ Layer class that implements a gating mechanism.
200
+
183
201
  Double the number of channels through a convolutional layer, then use
184
202
  half the channels as gate for the other half.
185
203
  """
186
204
 
187
- def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU):
205
+ def __init__(
206
+ self,
207
+ channels: int,
208
+ conv_strides: tuple[int] = (2, 2),
209
+ kernel_size: int = 3,
210
+ nonlin: Callable = nn.LeakyReLU(),
211
+ ):
188
212
  super().__init__()
189
213
  assert kernel_size % 2 == 1
190
214
  pad = kernel_size // 2
191
- self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad)
215
+ conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
216
+ self.conv = conv_layer(channels, 2 * channels, kernel_size, padding=pad)
192
217
  self.nonlin = nonlin
193
218
 
194
- def forward(self, x):
219
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
220
+ """Forward pass.
221
+
222
+ Parameters
223
+ ----------
224
+ x : torch.Tensor
225
+ input # TODO add shape
226
+
227
+ Returns
228
+ -------
229
+ torch.Tensor
230
+ output # TODO add shape
231
+ """
195
232
  x = self.conv(x)
196
233
  x, gate = torch.chunk(x, 2, dim=1)
197
234
  x = self.nonlin(x) # TODO remove this?
@@ -201,6 +238,8 @@ class GateLayer2d(nn.Module):
201
238
 
202
239
  class ResBlockWithResampling(nn.Module):
203
240
  """
241
+ Residual block with resampling.
242
+
204
243
  Residual block that takes care of resampling (i.e. downsampling or upsampling) steps (by a factor 2).
205
244
  It is structured as follows:
206
245
  1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or
@@ -210,7 +249,7 @@ class ResBlockWithResampling(nn.Module):
210
249
 
211
250
  Some implementation notes:
212
251
  - Resampling is performed through a strided convolution layer at the beginning of the block.
213
- - The strided convolution block has fixed kernel size of 3x3 and 1 layer of zero-padding.
252
+ - The strided convolution block has fixed kernel size of 3x3 and 1 layer of padding with zeros.
214
253
  - The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers.
215
254
  - The number of internal channels is by default the same as the number of output channels, but
216
255
  min_inner_channels can override the behaviour.
@@ -221,16 +260,16 @@ class ResBlockWithResampling(nn.Module):
221
260
  mode: Literal["top-down", "bottom-up"],
222
261
  c_in: int,
223
262
  c_out: int,
224
- min_inner_channels: int = None,
225
- nonlin: Callable = nn.LeakyReLU,
263
+ conv_strides: tuple[int],
264
+ min_inner_channels: Union[int, None] = None,
265
+ nonlin: Callable = nn.LeakyReLU(),
226
266
  resample: bool = False,
227
- res_block_kernel: Union[int, Iterable[int]] = None,
267
+ res_block_kernel: Optional[Union[int, Iterable[int]]] = None,
228
268
  groups: int = 1,
229
269
  batchnorm: bool = True,
230
- res_block_type: str = None,
231
- dropout: float = None,
232
- gated: bool = None,
233
- skip_padding: bool = False,
270
+ res_block_type: Union[str, None] = None,
271
+ dropout: Union[float, None] = None,
272
+ gated: Union[bool, None] = None,
234
273
  conv2d_bias: bool = True,
235
274
  # lowres_input: bool = False,
236
275
  ):
@@ -273,14 +312,15 @@ class ResBlockWithResampling(nn.Module):
273
312
  Default is `None`.
274
313
  gated: bool, optional
275
314
  Whether to use gated layer. Default is `None`.
276
- skip_padding: bool, optional
277
- Whether to skip padding in convolutions. Default is `False`.
278
315
  conv2d_bias: bool, optional
279
316
  Whether to use bias term in convolutions. Default is `True`.
280
317
  """
281
318
  super().__init__()
282
319
  assert mode in ["top-down", "bottom-up"]
283
320
 
321
+ conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
322
+ transp_conv_layer: ConvType = getattr(nn, f"ConvTranspose{len(conv_strides)}d")
323
+
284
324
  if min_inner_channels is None:
285
325
  min_inner_channels = 0
286
326
  # inner_channels is the number of channels used in the inner layers
@@ -290,28 +330,28 @@ class ResBlockWithResampling(nn.Module):
290
330
  # Define first conv layer to change num channels and/or up/downsample
291
331
  if resample:
292
332
  if mode == "bottom-up": # downsample
293
- self.pre_conv = nn.Conv2d(
333
+ self.pre_conv = conv_layer(
294
334
  in_channels=c_in,
295
335
  out_channels=inner_channels,
296
336
  kernel_size=3,
297
337
  padding=1,
298
- stride=2,
338
+ stride=conv_strides,
299
339
  groups=groups,
300
340
  bias=conv2d_bias,
301
341
  )
302
342
  elif mode == "top-down": # upsample
303
- self.pre_conv = nn.ConvTranspose2d(
343
+ self.pre_conv = transp_conv_layer(
304
344
  in_channels=c_in,
305
345
  kernel_size=3,
306
346
  out_channels=inner_channels,
307
- padding=1,
308
- stride=2,
347
+ padding=1, # TODO maybe don't hardcode this?
348
+ stride=conv_strides,
309
349
  groups=groups,
310
- output_padding=1,
350
+ output_padding=1 if len(conv_strides) == 2 else (0, 1, 1),
311
351
  bias=conv2d_bias,
312
352
  )
313
353
  elif c_in != inner_channels:
314
- self.pre_conv = nn.Conv2d(
354
+ self.pre_conv = conv_layer(
315
355
  c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
316
356
  )
317
357
  else:
@@ -320,6 +360,7 @@ class ResBlockWithResampling(nn.Module):
320
360
  # Residual block
321
361
  self.res = ResidualBlock(
322
362
  channels=inner_channels,
363
+ conv_strides=conv_strides,
323
364
  nonlin=nonlin,
324
365
  kernel=res_block_kernel,
325
366
  groups=groups,
@@ -327,19 +368,30 @@ class ResBlockWithResampling(nn.Module):
327
368
  dropout=dropout,
328
369
  gated=gated,
329
370
  block_type=res_block_type,
330
- skip_padding=skip_padding,
331
371
  conv2d_bias=conv2d_bias,
332
372
  )
333
373
 
334
374
  # Define last conv layer to get correct num output channels
335
375
  if inner_channels != c_out:
336
- self.post_conv = nn.Conv2d(
376
+ self.post_conv = conv_layer(
337
377
  inner_channels, c_out, 1, groups=groups, bias=conv2d_bias
338
378
  )
339
379
  else:
340
380
  self.post_conv = None
341
381
 
342
- def forward(self, x):
382
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
383
+ """Forward pass.
384
+
385
+ Parameters
386
+ ----------
387
+ x : torch.Tensor
388
+ input # TODO add shape
389
+
390
+ Returns
391
+ -------
392
+ torch.Tensor
393
+ output # TODO add shape
394
+ """
343
395
  if self.pre_conv is not None:
344
396
  x = self.pre_conv(x)
345
397
 
@@ -351,6 +403,7 @@ class ResBlockWithResampling(nn.Module):
351
403
 
352
404
 
353
405
  class TopDownDeterministicResBlock(ResBlockWithResampling):
406
+ """Resnet block for top-down deterministic layers."""
354
407
 
355
408
  def __init__(self, *args, upsample: bool = False, **kwargs):
356
409
  kwargs["resample"] = upsample
@@ -358,6 +411,7 @@ class TopDownDeterministicResBlock(ResBlockWithResampling):
358
411
 
359
412
 
360
413
  class BottomUpDeterministicResBlock(ResBlockWithResampling):
414
+ """Resnet block for bottom-up deterministic layers."""
361
415
 
362
416
  def __init__(self, *args, downsample: bool = False, **kwargs):
363
417
  kwargs["resample"] = downsample
@@ -367,6 +421,7 @@ class BottomUpDeterministicResBlock(ResBlockWithResampling):
367
421
  class BottomUpLayer(nn.Module):
368
422
  """
369
423
  Bottom-up deterministic layer.
424
+
370
425
  It consists of one or a stack of `BottomUpDeterministicResBlock`'s.
371
426
  The outputs are the so-called `bu_values` that are later used in the Decoder to update the
372
427
  generative distributions.
@@ -385,20 +440,20 @@ class BottomUpLayer(nn.Module):
385
440
  self,
386
441
  n_res_blocks: int,
387
442
  n_filters: int,
443
+ conv_strides: tuple[int] = (2, 2),
388
444
  downsampling_steps: int = 0,
389
- nonlin: Callable = None,
445
+ nonlin: Optional[Callable] = None,
390
446
  batchnorm: bool = True,
391
- dropout: float = None,
392
- res_block_type: str = None,
393
- res_block_kernel: int = None,
394
- res_block_skip_padding: bool = False,
395
- gated: bool = None,
447
+ dropout: Optional[float] = None,
448
+ res_block_type: Optional[str] = None,
449
+ res_block_kernel: Optional[int] = None,
450
+ gated: Optional[bool] = None,
396
451
  enable_multiscale: bool = False,
397
- multiscale_lowres_size_factor: int = None,
452
+ multiscale_lowres_size_factor: Optional[int] = None,
398
453
  lowres_separate_branch: bool = False,
399
454
  multiscale_retain_spatial_dims: bool = False,
400
455
  decoder_retain_spatial_dims: bool = False,
401
- output_expected_shape: Iterable[int] = None,
456
+ output_expected_shape: Optional[Iterable[int]] = None,
402
457
  ):
403
458
  """
404
459
  Constructor.
@@ -427,8 +482,6 @@ class BottomUpLayer(nn.Module):
427
482
  The kernel size used in the convolutions of the residual block.
428
483
  It can be either a single integer or a pair of integers defining the squared kernel.
429
484
  Default is `None`.
430
- res_block_skip_padding: bool, optional
431
- Whether to skip padding in convolutions in the Residual block. Default is `False`.
432
485
  gated: bool, optional
433
486
  Whether to use gated layer. Default is `None`.
434
487
  enable_multiscale: bool, optional
@@ -443,7 +496,8 @@ class BottomUpLayer(nn.Module):
443
496
  Whether to pad the latent tensor resulting from the bottom-up layer's primary flow
444
497
  to match the size of the low-res input. Default is `False`.
445
498
  decoder_retain_spatial_dims: bool, optional
446
- Default is `False`.
499
+ Whether in the corresponding top-down layer the shape of tensor is retained between
500
+ input and output. Default is `False`.
447
501
  output_expected_shape: Iterable[int], optional
448
502
  The expected shape of the layer output (only used if `enable_multiscale == True`).
449
503
  Default is `None`.
@@ -467,6 +521,7 @@ class BottomUpLayer(nn.Module):
467
521
  do_resample = True
468
522
  downsampling_steps -= 1
469
523
  block = BottomUpDeterministicResBlock(
524
+ conv_strides=conv_strides,
470
525
  c_in=n_filters,
471
526
  c_out=n_filters,
472
527
  nonlin=nonlin,
@@ -475,7 +530,6 @@ class BottomUpLayer(nn.Module):
475
530
  dropout=dropout,
476
531
  res_block_type=res_block_type,
477
532
  res_block_kernel=res_block_kernel,
478
- skip_padding=res_block_skip_padding,
479
533
  gated=gated,
480
534
  )
481
535
  if do_resample:
@@ -491,6 +545,7 @@ class BottomUpLayer(nn.Module):
491
545
  if self.enable_multiscale:
492
546
  self._init_multiscale(
493
547
  n_filters=n_filters,
548
+ conv_strides=conv_strides,
494
549
  nonlin=nonlin,
495
550
  batchnorm=batchnorm,
496
551
  dropout=dropout,
@@ -506,20 +561,25 @@ class BottomUpLayer(nn.Module):
506
561
  self,
507
562
  nonlin: Callable = None,
508
563
  n_filters: int = None,
564
+ conv_strides: tuple[int] = (2, 2),
509
565
  batchnorm: bool = None,
510
566
  dropout: float = None,
511
567
  res_block_type: str = None,
512
568
  ) -> None:
513
569
  """
514
- This method defines the modules responsible of merging compressed lateral inputs to the outputs
515
- of the primary flow at different hierarchical levels in the multiresolution approach (LC).
570
+ Bottom-up layer's method that initializes the LC modules.
516
571
 
517
- Specifically, the method initializes `lowres_net`, which is a stack of `BottomUpDeterministicBlock`'s
518
- (w/out downsampling) that takes care of additionally processing the low-res input, and `lowres_merge`,
519
- which is the module responsible of merging the compressed lateral input to the main flow.
572
+ Defines the modules responsible of merging compressed lateral inputs to the
573
+ outputs of the primary flow at different hierarchical levels in the
574
+ multiresolution approach (LC). Specifically, the method initializes `lowres_net`
575
+ , which is a stack of `BottomUpDeterministicBlock`'s (w/out downsampling) that
576
+ takes care of additionally processing the low-res input, and `lowres_merge`,
577
+ which is the module responsible of merging the compressed lateral input to the
578
+ main flow.
520
579
 
521
- NOTE: The merge modality is set by default to "residual", meaning that the merge layer
522
- performs concatenation on dim=1, followed by 1x1 convolution and a Residual Gated block.
580
+ NOTE: The merge modality is set by default to "residual", meaning that the
581
+ merge layer performs concatenation on dim=1, followed by 1x1 convolution and
582
+ a Residual Gated block.
523
583
 
524
584
  Parameters
525
585
  ----------
@@ -543,6 +603,7 @@ class BottomUpLayer(nn.Module):
543
603
 
544
604
  self.lowres_merge = MergeLowRes(
545
605
  channels=n_filters,
606
+ conv_strides=conv_strides,
546
607
  merge_type="residual",
547
608
  nonlin=nonlin,
548
609
  batchnorm=batchnorm,
@@ -553,9 +614,10 @@ class BottomUpLayer(nn.Module):
553
614
  )
554
615
 
555
616
  def forward(
556
- self, x: torch.Tensor, lowres_x: torch.Tensor = None
557
- ) -> Tuple[torch.Tensor, torch.Tensor]:
558
- """
617
+ self, x: torch.Tensor, lowres_x: Union[torch.Tensor, None] = None
618
+ ) -> tuple[torch.Tensor, torch.Tensor]:
619
+ """Forward pass.
620
+
559
621
  Parameters
560
622
  ----------
561
623
  x: torch.Tensor
@@ -563,6 +625,9 @@ class BottomUpLayer(nn.Module):
563
625
  previous layer.
564
626
  lowres_x: torch.Tensor, optional
565
627
  The low-res input used for Lateral Contextualization (LC). Default is `None`.
628
+
629
+ NOTE: first returned tensor is used as input for the next BU layer, while the second
630
+ tensor is the bu_value passed to the top-down layer.
566
631
  """
567
632
  # The input is fed through the residual downsampling block(s)
568
633
  primary_flow = self.net_downsized(x)
@@ -582,12 +647,25 @@ class BottomUpLayer(nn.Module):
582
647
  else:
583
648
  merged = primary_flow
584
649
 
650
+ # NOTE: Explanation of possible cases for the conditionals:
651
+ # - if both are `True` -> `merged` has the same spatial dims as the input (`x`) since
652
+ # spatial dims are retained by padding `primary_flow` in `MergeLowRes`. This is
653
+ # OK for the corresp TopDown layer, as it also retains spatial dims.
654
+ # - if both are `False` -> `merged`'s spatial dims are equal to `self.net_downsized(x)`,
655
+ # since no padding is done in `MergeLowRes` and, instead, the lowres input is cropped.
656
+ # This is OK for the corresp TopDown layer, as it also halves the spatial dims.
657
+ # - if 1st is `False` and 2nd is `True` -> not a concern, it cannot happen
658
+ # (see lvae.py, line 111, intialization of `multiscale_decoder_retain_spatial_dims`).
585
659
  if (
586
660
  self.multiscale_retain_spatial_dims is False
587
661
  or self.decoder_retain_spatial_dims is True
588
662
  ):
589
663
  return merged, merged
590
664
 
665
+ # NOTE: if we reach here, it means that `multiscale_retain_spatial_dims` is `True`,
666
+ # but `decoder_retain_spatial_dims` is `False`, meaning that merging LC preserves
667
+ # the spatial dimensions, but at the same time we don't want to retain the spatial
668
+ # dims in the corresponding top-down layer. Therefore, we need to crop the tensor.
591
669
  if self.output_expected_shape is not None:
592
670
  expected_shape = self.output_expected_shape
593
671
  else:
@@ -602,7 +680,10 @@ class BottomUpLayer(nn.Module):
602
680
 
603
681
  class MergeLayer(nn.Module):
604
682
  """
605
- This layer merges two or more 4D input tensors by concatenating along dim=1 and passes the result through:
683
+ Layer class that merges two or more input tensors.
684
+
685
+ Merges two or more (B, C, [Z], Y, X) input tensors by concatenating
686
+ them along dim=1 and passes the result through:
606
687
  a) a convolutional 1x1 layer (`merge_type == "linear"`), or
607
688
  b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
608
689
  c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
@@ -612,13 +693,13 @@ class MergeLayer(nn.Module):
612
693
  self,
613
694
  merge_type: Literal["linear", "residual", "residual_ungated"],
614
695
  channels: Union[int, Iterable[int]],
615
- nonlin: Callable = nn.LeakyReLU,
696
+ conv_strides: tuple[int] = (2, 2),
697
+ nonlin: Callable = nn.LeakyReLU(),
616
698
  batchnorm: bool = True,
617
- dropout: float = None,
618
- res_block_type: str = None,
619
- res_block_kernel: int = None,
620
- res_block_skip_padding: bool = False,
621
- conv2d_bias: bool = True,
699
+ dropout: Optional[float] = None,
700
+ res_block_type: Optional[str] = None,
701
+ res_block_kernel: Optional[int] = None,
702
+ conv2d_bias: Optional[bool] = True,
622
703
  ):
623
704
  """
624
705
  Constructor.
@@ -626,16 +707,21 @@ class MergeLayer(nn.Module):
626
707
  Parameters
627
708
  ----------
628
709
  merge_type: Literal["linear", "residual", "residual_ungated"]
629
- The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
630
- Check the class docstring for more information about the behaviour of different merge modalities.
710
+ The type of merge done in the layer. It can be chosen between "linear",
711
+ "residual", and "residual_ungated". Check the class docstring for more
712
+ information about the behaviour of different merge modalities.
631
713
  channels: Union[int, Iterable[int]]
632
714
  The number of channels used in the convolutional blocks of this layer.
633
715
  If it is an `int`:
634
716
  - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
635
717
  - (Optional) ResBlock: in_channels=channels, out_channels=channels
636
718
  If it is an Iterable (must have `len(channels)==3`):
637
- - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
638
- - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
719
+ - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]),
720
+ out_channels=channels[-1]
721
+ - (Optional) ResBlock: in_channels=channels[-1],
722
+ out_channels=channels[-1]
723
+ conv_strides: tuple, optional
724
+ The strides used in the convolutions. Default is `(2, 2)`.
639
725
  nonlin: Callable, optional
640
726
  The non-linearity function used in the block. Default is `nn.LeakyReLU`.
641
727
  batchnorm: bool, optional
@@ -649,10 +735,9 @@ class MergeLayer(nn.Module):
649
735
  Default is `None`.
650
736
  res_block_kernel: Union[int, Iterable[int]], optional
651
737
  The kernel size used in the convolutions of the residual block.
652
- It can be either a single integer or a pair of integers defining the squared kernel.
738
+ It can be either a single integer or a pair of integers defining the squared
739
+ kernel.
653
740
  Default is `None`.
654
- res_block_skip_padding: bool, optional
655
- Whether to skip padding in convolutions in the Residual block. Default is `False`.
656
741
  conv2d_bias: bool, optional
657
742
  Whether to use bias term in convolutions. Default is `True`.
658
743
  """
@@ -665,42 +750,42 @@ class MergeLayer(nn.Module):
665
750
  if len(channels) == 1:
666
751
  channels = [channels[0]] * 3
667
752
 
668
- # assert len(channels) == 3
753
+ self.conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
669
754
 
670
755
  if merge_type == "linear":
671
- self.layer = nn.Conv2d(
756
+ self.layer = self.conv_layer(
672
757
  sum(channels[:-1]), channels[-1], 1, bias=conv2d_bias
673
758
  )
674
759
  elif merge_type == "residual":
675
760
  self.layer = nn.Sequential(
676
- nn.Conv2d(
761
+ self.conv_layer(
677
762
  sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
678
763
  ),
679
764
  ResidualGatedBlock(
680
- channels[-1],
681
- nonlin,
765
+ conv_strides=conv_strides,
766
+ channels=channels[-1],
767
+ nonlin=nonlin,
682
768
  batchnorm=batchnorm,
683
769
  dropout=dropout,
684
770
  block_type=res_block_type,
685
771
  kernel=res_block_kernel,
686
772
  conv2d_bias=conv2d_bias,
687
- skip_padding=res_block_skip_padding,
688
773
  ),
689
774
  )
690
775
  elif merge_type == "residual_ungated":
691
776
  self.layer = nn.Sequential(
692
- nn.Conv2d(
777
+ self.conv_layer(
693
778
  sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
694
779
  ),
695
780
  ResidualBlock(
696
- channels[-1],
697
- nonlin,
781
+ conv_strides=conv_strides,
782
+ channels=channels[-1],
783
+ nonlin=nonlin,
698
784
  batchnorm=batchnorm,
699
785
  dropout=dropout,
700
786
  block_type=res_block_type,
701
787
  kernel=res_block_kernel,
702
788
  conv2d_bias=conv2d_bias,
703
- skip_padding=res_block_skip_padding,
704
789
  ),
705
790
  )
706
791
 
@@ -717,7 +802,9 @@ class MergeLayer(nn.Module):
717
802
 
718
803
  class MergeLowRes(MergeLayer):
719
804
  """
720
- Child class of `MergeLayer`, specifically designed to merge the low-resolution patches
805
+ Child class of `MergeLayer`.
806
+
807
+ Specifically designed to merge the low-resolution patches
721
808
  that are used in Lateral Contextualization approach.
722
809
  """
723
810
 
@@ -727,7 +814,8 @@ class MergeLowRes(MergeLayer):
727
814
  super().__init__(*args, **kwargs)
728
815
 
729
816
  def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
730
- """
817
+ """Forward pass.
818
+
731
819
  Parameters
732
820
  ----------
733
821
  latent: torch.Tensor
@@ -735,25 +823,28 @@ class MergeLowRes(MergeLayer):
735
823
  lowres: torch.Tensor
736
824
  The low-res patch image to be merged to increase the context.
737
825
  """
826
+ # TODO: treat (X, Y) and Z differently (e.g., line 762)
738
827
  if self.retain_spatial_dims:
739
828
  # Pad latent tensor to match lowres tensor's shape
829
+ # Output.shape == Lowres.shape (== Input.shape),
830
+ # where Input is the input to the BU layer
740
831
  latent = pad_img_tensor(latent, lowres.shape[2:])
741
832
  else:
742
833
  # Crop lowres tensor to match latent tensor's shape
743
- lh, lw = lowres.shape[-2:]
744
- h = lh // self.multiscale_lowres_size_factor
745
- w = lw // self.multiscale_lowres_size_factor
746
- h_pad = (lh - h) // 2
747
- w_pad = (lw - w) // 2
748
- lowres = lowres[:, :, h_pad:-h_pad, w_pad:-w_pad]
834
+ lz, ly, lx = lowres.shape[2:]
835
+ z = lz // self.multiscale_lowres_size_factor
836
+ y = ly // self.multiscale_lowres_size_factor
837
+ x = lx // self.multiscale_lowres_size_factor
838
+ z_pad = (lz - z) // 2
839
+ y_pad = (ly - y) // 2
840
+ x_pad = (lx - x) // 2
841
+ lowres = lowres[:, :, z_pad:-z_pad, y_pad:-y_pad, x_pad:-x_pad]
749
842
 
750
843
  return super().forward(latent, lowres)
751
844
 
752
845
 
753
846
  class SkipConnectionMerger(MergeLayer):
754
- """
755
- A specialized `MergeLayer` module, designed to handle skip connections in the model.
756
- """
847
+ """Specialized `MergeLayer` module, handles skip connections in the model."""
757
848
 
758
849
  def __init__(
759
850
  self,
@@ -762,10 +853,10 @@ class SkipConnectionMerger(MergeLayer):
762
853
  batchnorm: bool,
763
854
  dropout: float,
764
855
  res_block_type: str,
856
+ conv_strides: tuple[int] = (2, 2),
765
857
  merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
766
858
  conv2d_bias: bool = True,
767
- res_block_kernel: int = None,
768
- res_block_skip_padding: bool = False,
859
+ res_block_kernel: Optional[int] = None,
769
860
  ):
770
861
  """
771
862
  Constructor.
@@ -780,15 +871,15 @@ class SkipConnectionMerger(MergeLayer):
780
871
  If it is an Iterable (must have `len(channels)==3`):
781
872
  - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
782
873
  - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
783
- batchnorm: bool, optional
784
- Whether to use batchnorm layers. Default is `True`.
785
- dropout: float, optional
874
+ batchnorm: bool
875
+ Whether to use batchnorm layers.
876
+ dropout: float
786
877
  The dropout probability in dropout layers. If `None` dropout is not used.
787
- Default is `None`.
788
- res_block_type: str, optional
878
+ res_block_type: str
789
879
  A string specifying the structure of residual block.
790
880
  Check `ResidualBlock` doscstring for more information.
791
- Default is `None`.
881
+ conv_strides: tuple, optional
882
+ The strides used in the convolutions. Default is `(2, 2)`.
792
883
  merge_type: Literal["linear", "residual", "residual_ungated"]
793
884
  The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
794
885
  Check the class docstring for more information about the behaviour of different merge modalities.
@@ -798,10 +889,9 @@ class SkipConnectionMerger(MergeLayer):
798
889
  The kernel size used in the convolutions of the residual block.
799
890
  It can be either a single integer or a pair of integers defining the squared kernel.
800
891
  Default is `None`.
801
- res_block_skip_padding: bool, optional
802
- Whether to skip padding in convolutions in the Residual block. Default is `False`.
803
892
  """
804
893
  super().__init__(
894
+ conv_strides=conv_strides,
805
895
  channels=channels,
806
896
  nonlin=nonlin,
807
897
  merge_type=merge_type,
@@ -810,26 +900,25 @@ class SkipConnectionMerger(MergeLayer):
810
900
  res_block_type=res_block_type,
811
901
  res_block_kernel=res_block_kernel,
812
902
  conv2d_bias=conv2d_bias,
813
- res_block_skip_padding=res_block_skip_padding,
814
903
  )
815
904
 
816
905
 
817
906
  class TopDownLayer(nn.Module):
818
- """
819
- Top-down inference layer.
907
+ """Top-down inference layer.
908
+
820
909
  It includes:
821
910
  - Stochastic sampling,
822
911
  - Computation of KL divergence,
823
912
  - A small deterministic ResNet that performs upsampling.
824
913
 
825
914
  NOTE 1:
826
- The algorithm for generative inference approximately works as follows:
827
- - p_params = output of top-down layer above
828
- - bu = inferred bottom-up value at this layer
829
- - q_params = merge(bu, p_params)
830
- - z = stochastic_layer(q_params)
831
- - (optional) get and merge skip connection from prev top-down layer
832
- - top-down deterministic ResNet
915
+ The algorithm for generative inference approximately works as follows:
916
+ - p_params = output of top-down layer above
917
+ - bu = inferred bottom-up value at this layer
918
+ - q_params = merge(bu, p_params)
919
+ - z = stochastic_layer(q_params)
920
+ - (optional) get and merge skip connection from prev top-down layer
921
+ - top-down deterministic ResNet
833
922
 
834
923
  NOTE 2:
835
924
  The Top-Down layer can work in two modes: inference and prediction/generative.
@@ -856,28 +945,26 @@ class TopDownLayer(nn.Module):
856
945
  z_dim: int,
857
946
  n_res_blocks: int,
858
947
  n_filters: int,
948
+ conv_strides: tuple[int],
859
949
  is_top_layer: bool = False,
860
- downsampling_steps: int = None,
861
- nonlin: Callable = None,
862
- merge_type: Literal["linear", "residual", "residual_ungated"] = None,
950
+ upsampling_steps: Union[int, None] = None,
951
+ nonlin: Union[Callable, None] = None,
952
+ merge_type: Union[
953
+ Literal["linear", "residual", "residual_ungated"], None
954
+ ] = None,
863
955
  batchnorm: bool = True,
864
- dropout: float = None,
956
+ dropout: Union[float, None] = None,
865
957
  stochastic_skip: bool = False,
866
- res_block_type: str = None,
867
- res_block_kernel: int = None,
868
- res_block_skip_padding: bool = None,
958
+ res_block_type: Union[str, None] = None,
959
+ res_block_kernel: Union[int, None] = None,
869
960
  groups: int = 1,
870
- gated: bool = None,
961
+ gated: Union[bool, None] = None,
871
962
  learn_top_prior: bool = False,
872
- top_prior_param_shape: Iterable[int] = None,
963
+ top_prior_param_shape: Union[Iterable[int], None] = None,
873
964
  analytical_kl: bool = False,
874
- bottomup_no_padding_mode: bool = False,
875
- topdown_no_padding_mode: bool = False,
876
965
  retain_spatial_dims: bool = False,
877
- restricted_kl: bool = False,
878
- vanilla_latent_hw: Iterable[int] = None,
879
- non_stochastic_version: bool = False,
880
- input_image_shape: Union[None, Tuple[int, int]] = None,
966
+ vanilla_latent_hw: Union[Iterable[int], None] = None,
967
+ input_image_shape: Union[tuple[int, int], None] = None,
881
968
  normalize_latent_factor: float = 1.0,
882
969
  conv2d_bias: bool = True,
883
970
  stochastic_use_naive_exponential: bool = False,
@@ -893,11 +980,13 @@ class TopDownLayer(nn.Module):
893
980
  The number of TopDownDeterministicResBlock blocks
894
981
  n_filters: int
895
982
  The number of channels present through out the layers of this block.
983
+ conv_strides: tuple, optional
984
+ The strides used in the convolutions. Default is `(2, 2)`.
896
985
  is_top_layer: bool, optional
897
986
  Whether the current layer is at the top of the Decoder hierarchy. Default is `False`.
898
- downsampling_steps: int, optional
899
- The number of downsampling steps that has to be done in this layer (typically 1).
900
- Default is `False`.
987
+ upsampling_steps: int, optional
988
+ The number of upsampling steps that has to be done in this layer (typically 1).
989
+ Default is `None`.
901
990
  nonlin: Callable, optional
902
991
  The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
903
992
  merge_type: Literal["linear", "residual", "residual_ungated"], optional
@@ -921,8 +1010,6 @@ class TopDownLayer(nn.Module):
921
1010
  The kernel size used in the convolutions of the residual block.
922
1011
  It can be either a single integer or a pair of integers defining the squared kernel.
923
1012
  Default is `None`.
924
- res_block_skip_padding: bool, optional
925
- Whether to skip padding in convolutions in the Residual block. Default is `None`.
926
1013
  groups: int, optional
927
1014
  The number of groups to consider in the convolutions. Default is 1.
928
1015
  gated: bool, optional
@@ -939,33 +1026,14 @@ class TopDownLayer(nn.Module):
939
1026
  If True, KL divergence is calculated according to the analytical formula.
940
1027
  Otherwise, an MC approximation using sampled latents is calculated.
941
1028
  Default is `False`.
942
- bottomup_no_padding_mode: bool, optional
943
- Whether padding is used in the different layers of the bottom-up pass.
944
- It is meaningful to know this in advance in order to assess whether before
945
- merging `bu_values` and `p_params` tensors any alignment is needed.
946
- Default is `False`.
947
- topdown_no_padding_mode: bool, optional
948
- Whether padding is used in the different layers of the top-down pass.
949
- It is meaningful to know this in advance in order to assess whether before
950
- merging `bu_values` and `p_params` tensors any alignment is needed.
951
- The same information is also needed in handling the skip connections between
952
- top-down layers. Default is `False`.
953
1029
  retain_spatial_dims: bool, optional
954
1030
  If `True`, the size of Encoder's latent space is kept to `input_image_shape` within the topdown layer.
955
1031
  This implies that the oput spatial size equals the input spatial size.
956
1032
  To achieve this, we centercrop the intermediate representation.
957
1033
  Default is `False`.
958
- restricted_kl: bool, optional
959
- Whether to compute the restricted version of KL Divergence.
960
- See `NormalStochasticBlock2d` module for more information about its computation.
961
- Default is `False`.
962
1034
  vanilla_latent_hw: Iterable[int], optional
963
1035
  The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
964
1036
  Default is `None`.
965
- non_stochastic_version: bool, optional
966
- Whether to replace the stochastic layer that samples a latent variable from the latent distribiution with
967
- a non-stochastic layer that simply drwas a sample as the mode of the latent distribution.
968
- Default is `False`.
969
1037
  input_image_shape: Tuple[int, int], optionalut
970
1038
  The shape of the input image tensor.
971
1039
  When `retain_spatial_dims` is set to `True`, this is used to ensure that the shape of this layer
@@ -990,13 +1058,13 @@ class TopDownLayer(nn.Module):
990
1058
  self.stochastic_skip = stochastic_skip
991
1059
  self.learn_top_prior = learn_top_prior
992
1060
  self.analytical_kl = analytical_kl
993
- self.bottomup_no_padding_mode = bottomup_no_padding_mode
994
- self.topdown_no_padding_mode = topdown_no_padding_mode
995
1061
  self.retain_spatial_dims = retain_spatial_dims
996
- self.latent_shape = input_image_shape if self.retain_spatial_dims else None
997
- self.non_stochastic_version = non_stochastic_version
1062
+ self.input_image_shape = (
1063
+ input_image_shape if len(conv_strides) == 3 else input_image_shape[1:]
1064
+ )
1065
+ self.latent_shape = self.input_image_shape if self.retain_spatial_dims else None
998
1066
  self.normalize_latent_factor = normalize_latent_factor
999
- self._vanilla_latent_hw = vanilla_latent_hw
1067
+ self._vanilla_latent_hw = vanilla_latent_hw # TODO: check this, it is not used
1000
1068
 
1001
1069
  # Define top layer prior parameters, possibly learnable
1002
1070
  if is_top_layer:
@@ -1004,28 +1072,28 @@ class TopDownLayer(nn.Module):
1004
1072
  torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
1005
1073
  )
1006
1074
 
1007
- # Downsampling steps left to do in this layer
1008
- dws_left = downsampling_steps
1075
+ # Upsampling steps left to do in this layer
1076
+ ups_left = upsampling_steps
1009
1077
 
1010
1078
  # Define deterministic top-down block, which is a sequence of deterministic
1011
- # residual blocks with (optional) downsampling.
1079
+ # residual blocks with (optional) upsampling.
1012
1080
  block_list = []
1013
1081
  for _ in range(n_res_blocks):
1014
1082
  do_resample = False
1015
- if dws_left > 0:
1083
+ if ups_left > 0:
1016
1084
  do_resample = True
1017
- dws_left -= 1
1085
+ ups_left -= 1
1018
1086
  block_list.append(
1019
1087
  TopDownDeterministicResBlock(
1020
1088
  c_in=n_filters,
1021
1089
  c_out=n_filters,
1090
+ conv_strides=conv_strides,
1022
1091
  nonlin=nonlin,
1023
1092
  upsample=do_resample,
1024
1093
  batchnorm=batchnorm,
1025
1094
  dropout=dropout,
1026
1095
  res_block_type=res_block_type,
1027
1096
  res_block_kernel=res_block_kernel,
1028
- skip_padding=res_block_skip_padding,
1029
1097
  gated=gated,
1030
1098
  conv2d_bias=conv2d_bias,
1031
1099
  groups=groups,
@@ -1033,32 +1101,24 @@ class TopDownLayer(nn.Module):
1033
1101
  )
1034
1102
  self.deterministic_block = nn.Sequential(*block_list)
1035
1103
 
1036
- # Define stochastic block with 2D convolutions
1037
- if self.non_stochastic_version:
1038
- self.stochastic = NonStochasticBlock2d(
1039
- c_in=n_filters,
1040
- c_vars=z_dim,
1041
- c_out=n_filters,
1042
- transform_p_params=(not is_top_layer),
1043
- groups=groups,
1044
- conv2d_bias=conv2d_bias,
1045
- )
1046
- else:
1047
- self.stochastic = NormalStochasticBlock2d(
1048
- c_in=n_filters,
1049
- c_vars=z_dim,
1050
- c_out=n_filters,
1051
- transform_p_params=(not is_top_layer),
1052
- vanilla_latent_hw=vanilla_latent_hw,
1053
- restricted_kl=restricted_kl,
1054
- use_naive_exponential=stochastic_use_naive_exponential,
1055
- )
1104
+ # Define stochastic block with convolutions
1105
+
1106
+ self.stochastic = NormalStochasticBlock(
1107
+ c_in=n_filters,
1108
+ c_vars=z_dim,
1109
+ c_out=n_filters,
1110
+ conv_dims=len(conv_strides),
1111
+ transform_p_params=(not is_top_layer),
1112
+ vanilla_latent_hw=vanilla_latent_hw,
1113
+ use_naive_exponential=stochastic_use_naive_exponential,
1114
+ )
1056
1115
 
1057
1116
  if not is_top_layer:
1058
1117
  # Merge layer: it combines bottom-up inference and top-down
1059
1118
  # generative outcomes to give posterior parameters
1060
1119
  self.merge = MergeLayer(
1061
1120
  channels=n_filters,
1121
+ conv_strides=conv_strides,
1062
1122
  merge_type=merge_type,
1063
1123
  nonlin=nonlin,
1064
1124
  batchnorm=batchnorm,
@@ -1072,6 +1132,7 @@ class TopDownLayer(nn.Module):
1072
1132
  if stochastic_skip:
1073
1133
  self.skip_connection_merger = SkipConnectionMerger(
1074
1134
  channels=n_filters,
1135
+ conv_strides=conv_strides,
1075
1136
  nonlin=nonlin,
1076
1137
  batchnorm=batchnorm,
1077
1138
  dropout=dropout,
@@ -1079,28 +1140,27 @@ class TopDownLayer(nn.Module):
1079
1140
  merge_type=merge_type,
1080
1141
  conv2d_bias=conv2d_bias,
1081
1142
  res_block_kernel=res_block_kernel,
1082
- res_block_skip_padding=res_block_skip_padding,
1083
1143
  )
1084
1144
 
1085
- # print(f'[{self.__class__.__name__}] normalize_latent_factor:{self.normalize_latent_factor}')
1086
-
1087
1145
  def sample_from_q(
1088
1146
  self,
1089
1147
  input_: torch.Tensor,
1090
1148
  bu_value: torch.Tensor,
1091
- var_clip_max: float = None,
1149
+ var_clip_max: Optional[float] = None,
1092
1150
  mask: torch.Tensor = None,
1093
1151
  ) -> torch.Tensor:
1094
1152
  """
1095
- This method computes the latent inference distribution q(z_i|z_{i+1}) amd samples a latent tensor from it.
1153
+ Method computes the latent inference distribution q(z_i|z_{i+1}).
1154
+
1155
+ Used for sampling a latent tensor from it.
1096
1156
 
1097
1157
  Parameters
1098
1158
  ----------
1099
1159
  input_: torch.Tensor
1100
- The input tensor to the layer, which is the output of the top-down layer above.
1160
+ The input tensor to the layer, which is the output of the top-down layer.
1101
1161
  bu_value: torch.Tensor
1102
- The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
1103
- at the correspondent hierarchical layer.
1162
+ The tensor defining the parameters /mu_q and /sigma_q computed during the
1163
+ bottom-up deterministic pass at the correspondent hierarchical layer.
1104
1164
  var_clip_max: float, optional
1105
1165
  The maximum value reachable by the log-variance of the latent distribution.
1106
1166
  Values exceeding this threshold are clipped. Default is `None`.
@@ -1127,9 +1187,11 @@ class TopDownLayer(nn.Module):
1127
1187
  input_: torch.Tensor,
1128
1188
  n_img_prior: int,
1129
1189
  ) -> torch.Tensor:
1130
- """
1131
- This method returns the parameters of the prior distribution p(z_i|z_{i+1}) for the latent tensor
1132
- depending on the hierarchical level of the layer and other specific conditions.
1190
+ """Return the parameters of the prior distribution p(z_i|z_{i+1}).
1191
+
1192
+ The parameters depend on the hierarchical level of the layer:
1193
+ - if it is the topmost level, parameters are the ones of the prior.
1194
+ - else, the input from the layer above is the parameters itself.
1133
1195
 
1134
1196
  Parameters
1135
1197
  ----------
@@ -1154,81 +1216,56 @@ class TopDownLayer(nn.Module):
1154
1216
 
1155
1217
  return p_params
1156
1218
 
1157
- def align_pparams_buvalue(
1158
- self, p_params: torch.Tensor, bu_value: torch.Tensor
1159
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1160
- """
1161
- In case the padding is not used either (or both) in encoder and decoder, we could have a shape mismatch
1162
- in the spatial dimensions (usually, dim=2 & dim=3).
1163
- This method performs a centercrop to ensure that both remain aligned.
1164
-
1165
- Parameters
1166
- ----------
1167
- p_params: torch.Tensor
1168
- The tensor defining the parameters /mu_p and /sigma_p for the latent distribution p(z_i|z_{i+1}).
1169
- bu_value: torch.Tensor
1170
- The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
1171
- at the correspondent hierarchical layer.
1172
- """
1173
- if bu_value.shape[-2:] != p_params.shape[-2:]:
1174
- assert self.bottomup_no_padding_mode is True # TODO WTF ?
1175
- if self.topdown_no_padding_mode is False:
1176
- assert bu_value.shape[-1] > p_params.shape[-1]
1177
- bu_value = F.center_crop(bu_value, p_params.shape[-2:])
1178
- else:
1179
- if bu_value.shape[-1] > p_params.shape[-1]:
1180
- bu_value = F.center_crop(bu_value, p_params.shape[-2:])
1181
- else:
1182
- p_params = F.center_crop(p_params, bu_value.shape[-2:])
1183
- return p_params, bu_value
1184
-
1185
1219
  def forward(
1186
1220
  self,
1187
- input_: torch.Tensor = None,
1188
- skip_connection_input: torch.Tensor = None,
1221
+ input_: Union[torch.Tensor, None] = None,
1222
+ skip_connection_input: Union[torch.Tensor, None] = None,
1189
1223
  inference_mode: bool = False,
- bu_value: torch.Tensor = None,
- n_img_prior: int = None,
- forced_latent: torch.Tensor = None,
- use_mode: bool = False,
+ bu_value: Union[torch.Tensor, None] = None,
+ n_img_prior: Union[int, None] = None,
+ forced_latent: Union[torch.Tensor, None] = None,
  force_constant_output: bool = False,
  mode_pred: bool = False,
  use_uncond_mode: bool = False,
- var_clip_max: float = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
- """
+ var_clip_max: Union[float, None] = None,
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+ """Forward pass.
+
  Parameters
  ----------
  input_: torch.Tensor, optional
- The input tensor to the layer, which is the output of the top-down layer above.
+ The input tensor to the layer, which is the output of the top-down layer.
  Default is `None`.
  skip_connection_input: torch.Tensor, optional
- The tensor brought by the skip connection between the current and the previous top-down layer.
+ The tensor brought by the skip connection between the current and the
+ previous top-down layer.
  Default is `None`.
  inference_mode: bool, optional
- Whether the layer is in inference mode. See NOTE 2 in class description for more info.
+ Whether the layer is in inference mode. See NOTE 2 in the class description
+ for more info.
  Default is `False`.
  bu_value: torch.Tensor, optional
- The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
+ The tensor defining the parameters /mu_q and /sigma_q computed during the
+ bottom-up deterministic pass
  at the corresponding hierarchical layer. Default is `None`.
  n_img_prior: int, optional
- The number of images to be generated from the unconditional prior distribution p(z_L).
+ The number of images to be generated from the unconditional prior
+ distribution p(z_L).
  Default is `None`.
  forced_latent: torch.Tensor, optional
- A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and,
+ A pre-defined latent tensor. If it is not `None`, then it is used as the
+ actual latent tensor and,
  hence, sampling does not happen. Default is `None`.
- use_mode: bool, optional
- Whether the latent tensor should be set as the latent distribution mode.
- In the case of Gaussian, the mode coincides with the mean of the distribution.
- Default is `False`.
  force_constant_output: bool, optional
- Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
+ Whether to copy the first sample (and rel. distrib parameters) over the
+ whole batch.
  This is used when doing experiments from the prior - q is not used.
  Default is `False`.
  mode_pred: bool, optional
  Whether the model is in prediction mode. Default is `False`.
  use_uncond_mode: bool, optional
- Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
+ Whether to use the unconditional distribution p(z) to sample latents in
+ prediction mode.
  var_clip_max: float, optional
  The maximum value reachable by the log-variance of the latent distribution.
  Values exceeding this threshold are clipped.
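The signature change above moves from implicit-`None` defaults to explicit `Union[..., None]` annotations and to the built-in `tuple`/`dict` generics. A minimal sketch of the convention (a hypothetical function, not part of the package):

```python
from typing import Union

import torch


def forward(
    # Explicit "tensor or None" replaces the untyped `torch.Tensor = None`,
    # which strict type checkers reject.
    bu_value: Union[torch.Tensor, None] = None,
    var_clip_max: Union[float, None] = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    # PEP 585 built-in generics (`tuple`, `dict`) replace `typing.Tuple`/`Dict`.
    return torch.zeros(1), {}
```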
@@ -1241,26 +1278,33 @@ class TopDownLayer(nn.Module):
  p_params = self.get_p_params(input_, n_img_prior)

  # Get the parameters for the latent distribution to sample from
- if inference_mode: # TODO What's this ?
+ if inference_mode: # TODO What's this ? reuse Fede's code?
  if self.is_top_layer:
  q_params = bu_value
  if mode_pred is False:
- p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
+ assert p_params.shape[2:] == bu_value.shape[2:], (
+ "Spatial dimensions of p_params and bu_value should match. "
+ f"Instead, we got p_params={p_params.shape[2:]} and "
+ f"bu_value={bu_value.shape[2:]}."
+ )
  else:
  if use_uncond_mode:
  q_params = p_params
  else:
- p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value)
+ assert p_params.shape[2:] == bu_value.shape[2:], (
+ "Spatial dimensions of p_params and bu_value should match. "
+ f"Instead, we got p_params={p_params.shape[2:]} and "
+ f"bu_value={bu_value.shape[2:]}."
+ )
  q_params = self.merge(bu_value, p_params)
- # In generative mode, q is not used
- else:
+ else: # generative mode, q is not used, we sample from p(z_i | z_{i+1})
  q_params = None

  # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
  # depending on the mode (hence, in practice, by checking whether q_params is None).

- # Normalization of latent space parameters:
- # it is done, purely for stablity. See Very deep VAEs generalize autoregressive models.
+ # Normalization of latent space parameters for stability.
+ # See "Very deep VAEs generalize autoregressive models".
  if self.normalize_latent_factor:
  q_params = q_params / self.normalize_latent_factor

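The branch above selects which parameters feed the sampler: the bottom-up value at the top layer, the prior parameters when predicting unconditionally, or a merge of both; in generative mode `q_params` stays `None` and sampling falls back to p(z_i | z_{i+1}). A standalone sketch of the same control flow (names are illustrative, not the package's API):

```python
from typing import Callable, Union

import torch


def select_q_params(
    inference_mode: bool,
    is_top_layer: bool,
    use_uncond_mode: bool,
    bu_value: Union[torch.Tensor, None],
    p_params: torch.Tensor,
    merge: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Union[torch.Tensor, None]:
    if not inference_mode:
        return None  # generative mode: sample from p(z_i | z_{i+1})
    if is_top_layer:
        return bu_value  # q(z_L | x) comes straight from the bottom-up pass
    if use_uncond_mode:
        return p_params  # predict with the unconditional distribution
    return merge(bu_value, p_params)  # usual case: merge bottom-up and top-down
```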
@@ -1269,52 +1313,44 @@ class TopDownLayer(nn.Module):
  p_params=p_params,
  q_params=q_params,
  forced_latent=forced_latent,
- use_mode=use_mode,
  force_constant_output=force_constant_output,
  analytical_kl=self.analytical_kl,
  mode_pred=mode_pred,
  use_uncond_mode=use_uncond_mode,
  var_clip_max=var_clip_max,
  )
-
  # Merge skip connection from previous layer
  if self.stochastic_skip and not self.is_top_layer:
- if self.topdown_no_padding_mode is True:
- # If no padding is done in the current top-down pass, there may be a shape mismatch between current tensor and skip connection input.
- # As an example, if the output of last TopDownLayer was of size 64*64, due to lack of padding in the current layer, the current tensor
- # might become different in shape, say 60*60.
- # In order to avoid shape mismatch, we do central crop of the skip connection input.
- skip_connection_input = F.center_crop(
- skip_connection_input, x.shape[-2:]
- )
-
  x = self.skip_connection_merger(x, skip_connection_input)
-
- # Save activation before residual block as it can be the skip connection input in the next layer
- x_pre_residual = x
-
  if self.retain_spatial_dims:
- # when we don't want to do padding in topdown as well, we need to spare some boundary pixels which would be used up.
- extra_len = (self.topdown_no_padding_mode is True) * 3
-
- # this means that x should be of the same size as config.data.image_size. So, we have to centercrop by a factor of 2 at this point.
- # assert x.shape[-1] >= self.latent_shape[-1] // 2 + extra_len
- # we assume that one topdown layer will have exactly one upscaling layer.
- new_latent_shape = (
- self.latent_shape[0] // 2 + extra_len,
- self.latent_shape[1] // 2 + extra_len,
- )
-
- # If the LC is not applied on all layers, then this can happen.
+ # NOTE: we assume that one topdown layer will have exactly one upscaling layer.
+
+ # NOTE: in case LC in the Bottom-Up layer retains spatial dimensions,
+ # we have the following (see `MergeLowRes`):
+ # - the "primary-flow" tensor is padded to match the low-res patch size
+ # (e.g., from 32x32 to 64x64)
+ # - the padded tensor is then merged with the low-res patch (concatenation
+ # along dim=1 + convolution)
+ # Therefore, we need to do the symmetric operation here, that is to
+ # crop `x` by the same amount we padded it in the corresponding BU layer.
+
+ # NOTE: cropping is done to retain the shape of the input in the output.
+ # Therefore we need it only in the case `x` has the same shape as the input,
+ # because that's the only case in which we need to retain the shape.
+ # Here, `x` must be strictly greater than half the input shape, which is
+ # the case if and only if `x.shape == self.latent_shape`.
+ rescale = (
+ np.array((1, 2, 2)) if len(self.latent_shape) == 3 else np.array((2, 2))
+ ) # TODO better way?
+ new_latent_shape = tuple(np.array(self.latent_shape) // rescale)
  if x.shape[-1] > new_latent_shape[-1]:
- x = F.center_crop(x, new_latent_shape)
-
- # Last top-down block (sequence of residual blocks)
+ x = crop_img_tensor(x, new_latent_shape)
+ # TODO: `retain_spatial_dims` is the same for all the TD layers.
+ # How to handle the case in which we do not have LC for all layers?
+ # The answer is in `self.latent_shape`, which is equal to `input_image_shape`
+ # (e.g., (64, 64)) if `retain_spatial_dims` is `True`, else it is `None`.
+ # Last top-down block (sequence of residual blocks w/ upsampling)
  x = self.deterministic_block(x)
-
- if self.topdown_no_padding_mode:
- x = F.center_crop(x, self.latent_shape)
-
  # Save some metrics that will be used in the loss computation
  keys = [
  "z",
@@ -1322,7 +1358,6 @@ class TopDownLayer(nn.Module):
  "kl_samplewise_restricted",
  "kl_spatial",
  "kl_channelwise",
- # 'logprob_p',
  "logprob_q",
  "qvar_max",
  ]
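The `retain_spatial_dims` branch above relies on simple shape arithmetic: each top-down layer performs exactly one 2x upscaling, so the incoming tensor is center-cropped to half the stored latent shape, undoing the padding added by the matching bottom-up LC merge. A small sketch of that arithmetic, assuming `crop_img_tensor` center-crops to a target spatial size:

```python
import numpy as np

latent_shape = (64, 64)  # e.g., equal to the input image shape (2D case)

# One top-down layer performs exactly one 2x upscaling, so the tensor entering
# this layer should have half the stored latent shape; 3D latents (Z, Y, X)
# are only halved along Y and X.
rescale = np.array((1, 2, 2)) if len(latent_shape) == 3 else np.array((2, 2))
new_latent_shape = tuple(np.array(latent_shape) // rescale)
assert new_latent_shape == (32, 32)

# If `x` is spatially larger than this target (i.e., padding from the
# corresponding bottom-up LC merge is still present), it gets center-cropped:
# x = crop_img_tensor(x, new_latent_shape)
```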
@@ -1333,666 +1368,4 @@ class TopDownLayer(nn.Module):
  q_mu, q_lv = data_stoch["q_params"]
  data["q_mu"] = q_mu
  data["q_lv"] = q_lv
-
- return x, x_pre_residual, data
-
-
- class NormalStochasticBlock2d(nn.Module):
- """
- Stochastic block used in the Top-Down inference pass.
-
- Algorithm:
- - map input parameters to q(z) and (optionally) p(z) via convolution
- - sample a latent tensor z ~ q(z)
- - feed z to convolution and return.
-
- NOTE 1:
- If parameters for q are not given, sampling is done from p(z).
-
- NOTE 2:
- The restricted KL divergence is obtained by first computing the element-wise KL divergence
- (i.e., the KL computed for each element of the latent tensors). Then, the restricted version
- is computed by summing over the channels and the spatial dimensions associated only to the
- portion of the latent tensor that is used for prediction.
- """
-
- def __init__(
- self,
- c_in: int,
- c_vars: int,
- c_out: int,
- kernel: int = 3,
- transform_p_params: bool = True,
- vanilla_latent_hw: int = None,
- restricted_kl: bool = False,
- use_naive_exponential: bool = False,
- ):
- """
- Parameters
- ----------
- c_in: int
- The number of channels of the input tensor.
- c_vars: int
- The number of channels of the latent space tensor.
- c_out: int
- The output of the stochastic layer.
- Note that this is different from the sampled latent z.
- kernel: int, optional
- The size of the kernel used in convolutional layers.
- Default is 3.
- transform_p_params: bool, optional
- Whether a transformation should be applied to the `p_params` tensor.
- The transformation consists in a 2D convolution (`conv_in_p()`) that
- maps the input to a larger number of channels.
- Default is `True`.
- vanilla_latent_hw: int, optional
- The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
- Default is `None`.
- restricted_kl: bool, optional
- Whether to compute the restricted version of KL Divergence.
- See NOTE 2 for more information about its computation.
- Default is `False`.
- use_naive_exponential: bool, optional
- If `False`, exponentials are computed according to the alternative definition
- provided by the `StableExponential` class. This should improve numerical stability
- in the training process. Default is `False`.
- """
- super().__init__()
- assert kernel % 2 == 1
- pad = kernel // 2
- self.transform_p_params = transform_p_params
- self.c_in = c_in
- self.c_out = c_out
- self.c_vars = c_vars
- self._use_naive_exponential = use_naive_exponential
- self._vanilla_latent_hw = vanilla_latent_hw
- self._restricted_kl = restricted_kl
-
- if transform_p_params:
- self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
- self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
- self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)
-
- # def forward_swapped(self, p_params, q_mu, q_lv):
- #
- # if self.transform_p_params:
- # p_params = self.conv_in_p(p_params)
- # else:
- # assert p_params.size(1) == 2 * self.c_vars
- #
- # # Define p(z)
- # p_mu, p_lv = p_params.chunk(2, dim=1)
- # p = Normal(p_mu, (p_lv / 2).exp())
- #
- # # Define q(z)
- # q = Normal(q_mu, (q_lv / 2).exp())
- # # Sample from q(z)
- # sampling_distrib = q
- #
- # # Generate latent variable (typically by sampling)
- # z = sampling_distrib.rsample()
- #
- # # Output of stochastic layer
- # out = self.conv_out(z)
- #
- # data = {
- # 'z': z, # sampled variable at this layer (batch, ch, h, w)
- # 'p_params': p_params, # (b, ch, h, w) where b is 1 or batch size
- # }
- # return out, data
-
- def get_z(
- self,
- sampling_distrib: torch.distributions.normal.Normal,
- forced_latent: torch.Tensor,
- use_mode: bool,
- mode_pred: bool,
- use_uncond_mode: bool,
- ) -> torch.Tensor:
- """
- Sample a latent tensor given the distribution to sample from.
-
- The latent variable can be obtained in several ways:
- - Sampled from the (Gaussian) latent distribution.
- - Taken as a pre-defined forced latent.
- - Taken as the mode (mean) of the latent distribution.
- - In prediction mode (`mode_pred==True`), either sampled or taken as the distribution mode.
-
- Parameters
- ----------
- sampling_distrib: torch.distributions.normal.Normal
- The Gaussian distribution from which the latent tensor is sampled.
- forced_latent: torch.Tensor
- A pre-defined latent tensor. If it is not `None`, then it is used as the actual latent tensor and,
- hence, sampling does not happen.
- use_mode: bool
- Whether the latent tensor should be set as the latent distribution mode.
- In the case of Gaussian, the mode coincides with the mean of the distribution.
- mode_pred: bool
- Whether the model is in prediction mode.
- use_uncond_mode: bool
- Whether to use the unconditional distribution p(z) to sample latents in prediction mode.
- """
- if forced_latent is None:
- if use_mode:
- z = sampling_distrib.mean
- else:
- if mode_pred:
- if use_uncond_mode:
- z = sampling_distrib.mean
- else:
- z = sampling_distrib.rsample()
- else:
- z = sampling_distrib.rsample()
- else:
- z = forced_latent
- return z
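The options handled by `get_z` map onto a plain `torch.distributions.Normal`; a minimal, self-contained sketch (not the package's API):

```python
import torch
from torch.distributions import Normal

# A toy latent distribution with shape (batch, ch, h, w).
distrib = Normal(torch.zeros(2, 3, 8, 8), torch.ones(2, 3, 8, 8))

z_sampled = distrib.rsample()       # default: reparameterized sample
z_mode = distrib.mean               # use_mode=True: the Gaussian mode == mean
z_forced = torch.zeros(2, 3, 8, 8)  # forced_latent: bypasses sampling entirely
```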
-
- def sample_from_q(
- self, q_params: torch.Tensor, var_clip_max: float
- ) -> torch.Tensor:
- """
- Given an input parameter tensor defining q(z),
- process it by calling the `process_q_params()` method and
- sample a latent tensor from the resulting distribution.
-
- Parameters
- ----------
- q_params: torch.Tensor
- The input tensor to be processed.
- var_clip_max: float
- The maximum value reachable by the log-variance of the latent distribution.
- Values exceeding this threshold are clipped.
- """
- _, _, q = self.process_q_params(q_params, var_clip_max)
- return q.rsample()
-
- def compute_kl_metrics(
- self,
- p: torch.distributions.normal.Normal,
- p_params: torch.Tensor,
- q: torch.distributions.normal.Normal,
- q_params: torch.Tensor,
- mode_pred: bool,
- analytical_kl: bool,
- z: torch.Tensor,
- ) -> Dict[str, torch.Tensor]:
- """
- Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric.
- Specifically, the different versions of the KL loss terms are:
- - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
- - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
- - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
- used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
- - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
- - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
-
- Parameters
- ----------
- p: torch.distributions.normal.Normal
- The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
- p_params: torch.Tensor
- The parameters of the prior generative distribution.
- q: torch.distributions.normal.Normal
- The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
- q_params: torch.Tensor
- The parameters of the inference distribution.
- mode_pred: bool
- Whether the model is in prediction mode.
- analytical_kl: bool
- Whether to compute the KL divergence analytically or using Monte Carlo estimation.
- z: torch.Tensor
- The sampled latent tensor.
- """
- kl_samplewise_restricted = None
-
- if mode_pred is False: # if not in prediction mode
- # KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]
- if analytical_kl:
- kl_elementwise = kl_divergence(q, p)
- else:
- kl_elementwise = kl_normal_mc(z, p_params, q_params)
-
- # KL term only associated to the portion of the latent tensor that is used for prediction and
- # summed over channel and spatial dimensions. [Shape: (batch, )]
- # NOTE: vanilla_latent_hw is the shape of the latent tensor used for prediction, hence
- # the restriction has shape [Shape: (batch, ch, vanilla_latent_hw[0], vanilla_latent_hw[1])]
- if self._restricted_kl:
- pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2
- assert pad > 0, "Disable restricted kl since there is no restriction."
- tmp = kl_elementwise[..., pad:-pad, pad:-pad]
- kl_samplewise_restricted = tmp.sum((1, 2, 3))
-
- # KL term associated to each sample in the batch [Shape: (batch, )]
- kl_samplewise = kl_elementwise.sum((1, 2, 3))
-
- # KL term associated to each sample and each channel [Shape: (batch, ch, )]
- kl_channelwise = kl_elementwise.sum((2, 3))
-
- # KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
- kl_spatial = kl_elementwise.sum(1)
- else: # if predicting, no need to compute KL
- kl_elementwise = kl_samplewise = kl_spatial = kl_channelwise = None
-
- kl_dict = {
- "kl_elementwise": kl_elementwise, # (batch, ch, h, w)
- "kl_samplewise": kl_samplewise, # (batch, )
- "kl_samplewise_restricted": kl_samplewise_restricted, # (batch, )
- "kl_channelwise": kl_channelwise, # (batch, ch)
- "kl_spatial": kl_spatial, # (batch, h, w)
- }
- return kl_dict
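The KL variants listed in the docstring are plain sums of the element-wise KL over different axes; a sketch with a dummy tensor (the 16x16 `vanilla_latent_hw` is an assumed example value):

```python
import torch

kl_elementwise = torch.rand(4, 16, 32, 32)  # (batch, ch, h, w)

kl_samplewise = kl_elementwise.sum((1, 2, 3))  # (batch,)
kl_channelwise = kl_elementwise.sum((2, 3))    # (batch, ch)
kl_spatial = kl_elementwise.sum(1)             # (batch, h, w)

# Restricted variant: keep only the central crop actually used for prediction,
# e.g., a hypothetical 16x16 "vanilla" latent inside the 32x32 tensor.
vanilla_latent_hw = 16
pad = (kl_elementwise.shape[-1] - vanilla_latent_hw) // 2  # = 8
kl_samplewise_restricted = kl_elementwise[..., pad:-pad, pad:-pad].sum((1, 2, 3))
```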
-
- def process_p_params(
- self, p_params: torch.Tensor, var_clip_max: float
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
- """
- Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)).
-
- Processing consists of:
- - (optionally) 2D convolution on the input tensor to increase the number of channels.
- - split the resulting tensor into two chunks, the mean and the log-variance.
- - (optionally) clip the log-variance to an upper threshold.
- - define the normal distribution p(z) given the parameter tensors above.
-
- Parameters
- ----------
- p_params: torch.Tensor
- The input tensor to be processed.
- var_clip_max: float
- The maximum value reachable by the log-variance of the latent distribution.
- Values exceeding this threshold are clipped.
- """
- if self.transform_p_params:
- p_params = self.conv_in_p(p_params)
- else:
- assert p_params.size(1) == 2 * self.c_vars
-
- # Define p(z)
- p_mu, p_lv = p_params.chunk(2, dim=1)
- if var_clip_max is not None:
- p_lv = torch.clip(p_lv, max=var_clip_max)
-
- p_mu = StableMean(p_mu)
- p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential)
- p = Normal(p_mu.get(), p_lv.get_std())
- return p_mu, p_lv, p
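The processing steps amount to a channel-wise `chunk` into mean and log-variance, an optional `clip`, and a `Normal` built from them; sketched here with vanilla tensors, omitting the `StableMean`/`StableLogVar` wrappers:

```python
import torch
from torch.distributions import Normal

p_params = torch.randn(4, 2 * 16, 32, 32)  # conv output: 2 * c_vars channels
var_clip_max = 6.0

p_mu, p_lv = p_params.chunk(2, dim=1)      # (4, 16, 32, 32) each
p_lv = torch.clip(p_lv, max=var_clip_max)  # clip the log-variance only

p = Normal(p_mu, (p_lv / 2).exp())         # std = exp(logvar / 2)
```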
-
- def process_q_params(
- self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
- """
- Process the input parameters to get the inference distribution q(z_i|z_{i+1}) (or q(z|x)).
-
- Processing consists of:
- - 2D convolution on the input tensor to increase the number of channels.
- - split the resulting tensor into two chunks, the mean and the log-variance.
- - (optionally) clip the log-variance to an upper threshold.
- - (optionally) crop the resulting tensors to ensure that the last spatial dimension is even.
- - define the normal distribution q(z) given the parameter tensors above.
-
- Parameters
- ----------
- q_params: torch.Tensor
- The input tensor to be processed.
- var_clip_max: float
- The maximum value reachable by the log-variance of the latent distribution.
- Values exceeding this threshold are clipped.
- """
- q_params = self.conv_in_q(q_params)
-
- q_mu, q_lv = q_params.chunk(2, dim=1)
- if var_clip_max is not None:
- q_lv = torch.clip(q_lv, max=var_clip_max)
-
- if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
- q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
- q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1)
- # clip_start = np.random.rand() > 0.5
- # q_mu = q_mu[:, :, 1:, 1:] if clip_start else q_mu[:, :, :-1, :-1]
- # q_lv = q_lv[:, :, 1:, 1:] if clip_start else q_lv[:, :, :-1, :-1]
-
- q_mu = StableMean(q_mu)
- q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential)
- q = Normal(q_mu.get(), q_lv.get_std())
- return q_mu, q_lv, q
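The even-size cropping above can be reproduced in isolation; a sketch assuming `F` is `torchvision.transforms.functional`, as in this module:

```python
import torch
import torchvision.transforms.functional as F

q_mu = torch.randn(4, 16, 33, 33)  # odd spatial size
if q_mu.shape[-1] % 2 == 1:
    # Center-crop by one pixel so downstream 2x up/downscaling stays aligned.
    q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
assert q_mu.shape[-2:] == (32, 32)
```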
-
- def forward(
- self,
- p_params: torch.Tensor,
- q_params: torch.Tensor = None,
- forced_latent: torch.Tensor = None,
- use_mode: bool = False,
- force_constant_output: bool = False,
- analytical_kl: bool = False,
- mode_pred: bool = False,
- use_uncond_mode: bool = False,
- var_clip_max: float = None,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
- """
- Parameters
- ----------
- p_params: torch.Tensor
- The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
- q_params: torch.Tensor, optional
- The tensor resulting from merging the bu_value tensor at the same hierarchical level
- from the bottom-up pass and the `p_params` tensor. Default is `None`.
- forced_latent: torch.Tensor, optional
- A pre-defined latent tensor. If it is not `None`, then it is used as the actual latent
- tensor and, hence, sampling does not happen. Default is `None`.
- use_mode: bool, optional
- Whether the latent tensor should be set as the latent distribution mode.
- In the case of Gaussian, the mode coincides with the mean of the distribution.
- Default is `False`.
- force_constant_output: bool, optional
- Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
- This is used when doing experiments from the prior - q is not used.
- Default is `False`.
- analytical_kl: bool, optional
- Whether to compute the KL divergence analytically or using Monte Carlo estimation.
- Default is `False`.
- mode_pred: bool, optional
- Whether the model is in prediction mode. Default is `False`.
- use_uncond_mode: bool, optional
- Whether to use the unconditional distribution p(z) to sample latents in prediction mode.
- Default is `False`.
- var_clip_max: float, optional
- The maximum value reachable by the log-variance of the latent distribution.
- Values exceeding this threshold are clipped. Default is `None`.
- """
- debug_qvar_max = 0
-
- # Check sampling options consistency
- assert (forced_latent is None) or (not use_mode)
-
- # Get generative distribution p(z_i|z_{i+1})
- p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max)
- p_params = (p_mu, p_lv)
-
- if q_params is not None:
- # Get inference distribution q(z_i|z_{i+1})
- # NOTE: At inference time, don't centercrop the q_params even if they are odd in size.
- q_mu, q_lv, q = self.process_q_params(
- q_params, var_clip_max, allow_oddsizes=mode_pred is True
- )
- q_params = (q_mu, q_lv)
- sampling_distrib = q
- debug_qvar_max = torch.max(q_lv.get())
-
- # Centercrop p_params so that their size matches the one of q_params
- q_size = q_mu.get().shape[-1]
- if p_mu.get().shape[-1] != q_size and mode_pred is False:
- p_mu.centercrop_to_size(q_size)
- p_lv.centercrop_to_size(q_size)
- else:
- sampling_distrib = p
-
- # Sample latent variable
- z = self.get_z(
- sampling_distrib, forced_latent, use_mode, mode_pred, use_uncond_mode
- )
-
- # Copy one sample (and distrib parameters) over the whole batch.
- # This is used when doing experiments from the prior - q is not used.
- if force_constant_output:
- z = z[0:1].expand_as(z).clone()
- p_params = (
- p_params[0][0:1].expand_as(p_params[0]).clone(),
- p_params[1][0:1].expand_as(p_params[1]).clone(),
- )
-
- # Pass the sampled latent through the output convolutional layer of the stochastic block
- out = self.conv_out(z)
-
- # Compute log p(z). NOTE: disabling its computation.
- # if mode_pred is False:
- # logprob_p = p.log_prob(z).sum((1, 2, 3))
- # else:
- # logprob_p = None
-
- if q_params is not None:
- # Compute log q(z)
- logprob_q = q.log_prob(z).sum((1, 2, 3))
- # Compute KL divergence metrics
- kl_dict = self.compute_kl_metrics(
- p, p_params, q, q_params, mode_pred, analytical_kl, z
- )
- else:
- kl_dict = {}
- logprob_q = None
-
- # Store meaningful quantities to use them in following layers
- data = kl_dict
- data["z"] = z # sampled variable at this layer (batch, ch, h, w)
- data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size
- data["q_params"] = q_params # (batch, ch, h, w)
- # data['logprob_p'] = logprob_p # (batch, )
- data["logprob_q"] = logprob_q # (batch, )
- data["qvar_max"] = debug_qvar_max
-
- return out, data
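The `force_constant_output` branch in the removed `forward` above copies the first batch element over the whole batch; in isolation:

```python
import torch

z = torch.randn(8, 16, 32, 32)
z_const = z[0:1].expand_as(z).clone()  # broadcast sample 0 across the batch
assert torch.equal(z_const[3], z[0])
```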
-
-
- class NonStochasticBlock2d(nn.Module):
- """
- Non-stochastic version of the NormalStochasticBlock2d.
- """
-
- def __init__(
- self,
- c_vars: int,
- c_in: int,
- c_out: int,
- kernel: int = 3,
- groups: int = 1,
- conv2d_bias: bool = True,
- transform_p_params: bool = True,
- ):
- """
- Constructor.
-
- Parameters
- ----------
- c_vars: int
- The number of channels of the latent space tensor.
- c_in: int
- The number of channels of the input tensor.
- c_out: int
- The output of the stochastic layer.
- Note that this is different from the sampled latent z.
- kernel: int, optional
- The size of the kernel used in convolutional layers.
- Default is 3.
- groups: int, optional
- The number of groups to consider in the convolutions of this layer.
- Default is 1.
- conv2d_bias: bool, optional
- Whether to use a bias term in the convolutional blocks of this layer.
- Default is `True`.
- transform_p_params: bool, optional
- Whether a transformation should be applied to the `p_params` tensor.
- The transformation consists in a 2D convolution (`conv_in_p()`) that
- maps the input to a larger number of channels.
- Default is `True`.
- """
- super().__init__()
- assert kernel % 2 == 1
- pad = kernel // 2
- self.transform_p_params = transform_p_params
- self.c_in = c_in
- self.c_out = c_out
- self.c_vars = c_vars
-
- if transform_p_params:
- self.conv_in_p = nn.Conv2d(
- c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
- )
- self.conv_in_q = nn.Conv2d(
- c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups
- )
- self.conv_out = nn.Conv2d(
- c_vars, c_out, kernel, padding=pad, bias=conv2d_bias, groups=groups
- )
-
- def compute_kl_metrics(
- self,
- p: torch.distributions.normal.Normal,
- p_params: torch.Tensor,
- q: torch.distributions.normal.Normal,
- q_params: torch.Tensor,
- mode_pred: bool,
- analytical_kl: bool,
- z: torch.Tensor,
- ) -> Dict[str, None]:
- """
- Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric.
- Specifically, the different versions of the KL loss terms are:
- - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
- - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
- - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
- used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
- - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
- - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]
-
- NOTE: in this class all the KL metrics are set to `None`.
-
- Parameters
- ----------
- p: torch.distributions.normal.Normal
- The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
- p_params: torch.Tensor
- The parameters of the prior generative distribution.
- q: torch.distributions.normal.Normal
- The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
- q_params: torch.Tensor
- The parameters of the inference distribution.
- mode_pred: bool
- Whether the model is in prediction mode.
- analytical_kl: bool
- Whether to compute the KL divergence analytically or using Monte Carlo estimation.
- z: torch.Tensor
- The sampled latent tensor.
- """
- kl_dict = {
- "kl_elementwise": None, # (batch, ch, h, w)
- "kl_samplewise": None, # (batch, )
- "kl_spatial": None, # (batch, h, w)
- "kl_channelwise": None, # (batch, ch)
- }
- return kl_dict
-
- def process_p_params(self, p_params, var_clip_max):
- if self.transform_p_params:
- p_params = self.conv_in_p(p_params)
- else:
- assert (
- p_params.size(1) == 2 * self.c_vars
- ), f"{p_params.shape} {self.c_vars}"
-
- # Define p(z)
- p_mu, p_lv = p_params.chunk(2, dim=1)
- return p_mu, None
-
- def process_q_params(self, q_params, var_clip_max, allow_oddsizes=False):
- # Define q(z)
- q_params = self.conv_in_q(q_params)
- q_mu, q_lv = q_params.chunk(2, dim=1)
-
- if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
- q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
-
- return q_mu, None
-
- def forward(
- self,
- p_params: torch.Tensor,
- q_params: torch.Tensor = None,
- forced_latent: Union[None, torch.Tensor] = None,
- use_mode: bool = False,
- force_constant_output: bool = False,
- analytical_kl: bool = False,
- mode_pred: bool = False,
- use_uncond_mode: bool = False,
- var_clip_max: float = None,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
- """
- Parameters
- ----------
- p_params: torch.Tensor
- The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
- q_params: torch.Tensor, optional
- The tensor resulting from merging the bu_value tensor at the same hierarchical level
- from the bottom-up pass and the `p_params` tensor. Default is `None`.
- forced_latent: torch.Tensor, optional
- A pre-defined latent tensor. If it is not `None`, then it is used as the actual latent
- tensor and, hence, sampling does not happen. Default is `None`.
- use_mode: bool, optional
- Whether the latent tensor should be set as the latent distribution mode.
- In the case of Gaussian, the mode coincides with the mean of the distribution.
- Default is `False`.
- force_constant_output: bool, optional
- Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
- This is used when doing experiments from the prior - q is not used.
- Default is `False`.
- analytical_kl: bool, optional
- Whether to compute the KL divergence analytically or using Monte Carlo estimation.
- Default is `False`.
- mode_pred: bool, optional
- Whether the model is in prediction mode. Default is `False`.
- use_uncond_mode: bool, optional
- Whether to use the unconditional distribution p(z) to sample latents in prediction mode.
- Default is `False`.
- var_clip_max: float, optional
- The maximum value reachable by the log-variance of the latent distribution.
- Values exceeding this threshold are clipped. Default is `None`.
- """
- debug_qvar_max = 0
- assert (forced_latent is None) or (not use_mode)
-
- p_mu, _ = self.process_p_params(p_params, var_clip_max)
-
- p_params = (p_mu, None)
-
- if q_params is not None:
- # At inference time, just don't centercrop the q_params even if they are odd in size.
- q_mu, _ = self.process_q_params(
- q_params, var_clip_max, allow_oddsizes=mode_pred is True
- )
- q_params = (q_mu, None)
- debug_qvar_max = torch.Tensor([1]).to(q_mu.device)
- # Sample from q(z)
- sampling_distrib = q_mu
- q_size = q_mu.shape[-1]
- if p_mu.shape[-1] != q_size and mode_pred is False:
- p_mu.centercrop_to_size(q_size)
- else:
- # Sample from p(z)
- sampling_distrib = p_mu
-
- # Generate latent variable (typically by sampling)
- z = sampling_distrib
-
- # Copy one sample (and distrib parameters) over the whole batch.
- # This is used when doing experiments from the prior - q is not used.
- if force_constant_output:
- z = z[0:1].expand_as(z).clone()
- p_params = (
- p_params[0][0:1].expand_as(p_params[0]).clone(),
- p_params[1][0:1].expand_as(p_params[1]).clone(),
- )
-
- # Output of stochastic layer
- out = self.conv_out(z)
-
- kl_dict = {}
- logprob_q = None
-
- data = kl_dict
- data["z"] = z # sampled variable at this layer (batch, ch, h, w)
- data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size
- data["q_params"] = q_params # (batch, ch, h, w)
- data["logprob_q"] = logprob_q # (batch, )
- data["qvar_max"] = debug_qvar_max
-
- return out, data
+ return x, data