careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,1371 @@
1
+ """Script containing the common basic blocks (nn.Module) reused by the LadderVAE."""
2
+
3
+ from collections.abc import Iterable
4
+ from copy import deepcopy
5
+ from typing import Callable, Literal, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .stochastic import NormalStochasticBlock
12
+ from .utils import (
13
+ crop_img_tensor,
14
+ pad_img_tensor,
15
+ )
16
+
17
# Aliases describing the dimension-dependent layer classes (2D or 3D variants)
# that the blocks in this module look up dynamically with `getattr(nn, ...)`.
ConvType = Union[nn.Conv2d, nn.Conv3d]  # convolution layers
NormType = Union[nn.BatchNorm2d, nn.BatchNorm3d]  # batch-norm layers
DropoutType = Union[nn.Dropout2d, nn.Dropout3d]  # channel-dropout layers
20
+
21
+
22
class ResidualBlock(nn.Module):
    """
    Residual block with 2 convolutional layers.

    Some architectural notes:
        - The number of input, intermediate, and output channels is the same,
        - Padding is always 'same',
        - The 2 convolutional layers have the same groups,
        - No stride allowed,
        - Kernel sizes must be odd.

    The output is given by: `out = gate(f(x)) + x`.
    The presence of the gating mechanism is optional, and f(x) has different
    structures depending on the `block_type` argument.
    Specifically, `block_type` is a string specifying the block's structure, with:
        a = activation
        b = batch norm
        c = conv layer
        d = dropout.
    For example, "bacdbacd" defines a block with 2x[batchnorm, activation, conv, dropout].
    Supported values are "cabdcabd", "bacdbac" and "bacdbacd". Batch norm layers
    are only inserted when `batchnorm=True`, and dropout layers only when
    `dropout` is not `None`.
    """

    default_kernel_size = (3, 3)

    def __init__(
        self,
        channels: int,
        nonlin: Callable,
        conv_strides: tuple[int, ...] = (2, 2),
        kernel: Union[int, Iterable[int], None] = None,
        groups: int = 1,
        batchnorm: bool = True,
        block_type: Optional[str] = None,
        dropout: Optional[float] = None,
        gated: Optional[bool] = None,
        conv2d_bias: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        channels: int
            The number of input and output channels (they are the same).
        nonlin: Callable
            The non-linearity module used in the block (e.g., `nn.ReLU()`).
            The same instance is inserted once per convolution.
        conv_strides: tuple[int, ...], optional
            Only its length is used: it selects 2D (length 2) or 3D (length 3)
            convolution, batch norm and dropout layers. Default is `(2, 2)`.
        kernel: Union[int, Iterable[int]], optional
            Kernel sizes of the two convolutions of the block. A single integer
            is used for both convolutions; an iterable must yield exactly two
            odd integers, the first for the first convolution and the second
            for the second one. Default is `None`, i.e. `default_kernel_size`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        block_type: str, optional
            A string specifying the block structure, check class docstring for more info.
            Default is `None`, which is rejected with a `ValueError`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to append a `GateLayer` (1x1 convolution) at the end of the
            block. Default is `None` (no gate).
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.

        Raises
        ------
        ValueError
            If `kernel` is an iterable whose length is not 2, or if
            `block_type` is not one of the supported block types.
        AssertionError
            If any kernel size is even.
        """
        super().__init__()

        # Normalize `kernel` to a list with one kernel size per convolution.
        if kernel is None:
            kernel = self.default_kernel_size
        elif isinstance(kernel, int):
            kernel = (kernel, kernel)
        else:
            # Materialize first so one-shot iterables (e.g. generators) can be
            # length-checked and indexed below.
            kernel = tuple(kernel)
            if len(kernel) != 2:
                raise ValueError(
                    "kernel has to be None, int, or an iterable of length 2"
                )
        assert all(k % 2 == 1 for k in kernel), "kernel sizes have to be odd"
        kernel = list(kernel)

        # The number of entries in `conv_strides` selects 2D vs 3D layers.
        # TODO: same comment as in lvae.py, would be more readable to have `conv_dims`
        conv_dims = len(conv_strides)
        conv_layer: ConvType = getattr(nn, f"Conv{conv_dims}d")
        norm_layer: NormType = getattr(nn, f"BatchNorm{conv_dims}d")
        dropout_layer: DropoutType = getattr(nn, f"Dropout{conv_dims}d")

        def make_conv(kernel_size: int) -> nn.Module:
            # Channel-preserving convolution; 'same' padding keeps the spatial
            # size so the residual sum in `forward` is shape-compatible.
            return conv_layer(
                channels,
                channels,
                kernel_size,
                padding="same",
                groups=groups,
                bias=conv2d_bias,
            )

        modules = []
        if block_type == "cabdcabd":
            for i in range(2):
                modules.append(make_conv(kernel[i]))
                modules.append(nonlin)
                if batchnorm:
                    modules.append(norm_layer(channels))
                if dropout is not None:
                    modules.append(dropout_layer(dropout))
        elif block_type == "bacdbac":
            for i in range(2):
                if batchnorm:
                    modules.append(norm_layer(channels))
                modules.append(nonlin)
                modules.append(make_conv(kernel[i]))
                # No trailing "d": dropout only follows the first convolution.
                if dropout is not None and i == 0:
                    modules.append(dropout_layer(dropout))
        elif block_type == "bacdbacd":
            for i in range(2):
                if batchnorm:
                    modules.append(norm_layer(channels))
                modules.append(nonlin)
                modules.append(make_conv(kernel[i]))
                # Only add dropout when a probability is given:
                # `dropout_layer(None)` raises instead of disabling dropout.
                if dropout is not None:
                    modules.append(dropout_layer(dropout))
        else:
            raise ValueError(f"unrecognized block type '{block_type}'")

        self.gated = gated
        if gated:
            modules.append(
                GateLayer(
                    channels=channels,
                    conv_strides=conv_strides,
                    kernel_size=1,
                    nonlin=nonlin,
                )
            )

        self.block = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with `channels` channels along dim 1; presumably
            (B, C, H, W) for 2D or (B, C, Z, H, W) for 3D layers — TODO confirm.

        Returns
        -------
        torch.Tensor
            `block(x) + x`, with the same shape as `x`.
        """
        out = self.block(x)
        assert (
            out.shape == x.shape
        ), f"output shape: {out.shape} != input shape: {x.shape}"
        return out + x
188
+
189
+
190
class ResidualGatedBlock(ResidualBlock):
    """Residual block that always ends with a gating layer.

    All positional and keyword arguments are forwarded unchanged to
    `ResidualBlock`, except `gated`, which is forced to `True`. Supplying
    `gated` explicitly (by keyword or position) therefore raises a
    `TypeError` for a duplicated argument.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, gated=True, **kwargs)
195
+
196
+
197
class GateLayer(nn.Module):
    """
    Gating layer.

    A single convolution doubles the number of channels. The first half of the
    result is passed through `nonlin` and multiplied element-wise by the
    sigmoid of the second half, so the output has `channels` channels again.
    """

    def __init__(
        self,
        channels: int,
        conv_strides: tuple[int] = (2, 2),
        kernel_size: int = 3,
        nonlin: Callable = nn.LeakyReLU(),
    ):
        """
        Constructor.

        Parameters
        ----------
        channels: int
            Number of input (and output) channels.
        conv_strides: tuple[int], optional
            Only its length is used, to pick `nn.Conv2d` (2) or `nn.Conv3d` (3).
            Default is `(2, 2)`.
        kernel_size: int, optional
            Odd kernel size of the gating convolution. A padding of
            `kernel_size // 2` with unit stride keeps the spatial size unchanged.
            Default is 3.
        nonlin: Callable, optional
            Activation applied to the non-gate half. The default `nn.LeakyReLU()`
            instance is created once, when the function is defined, and is shared
            by every layer that relies on the default.
        """
        super().__init__()
        assert kernel_size % 2 == 1
        n_dims = len(conv_strides)
        conv_cls: ConvType = getattr(nn, f"Conv{n_dims}d")
        self.conv = conv_cls(
            channels, 2 * channels, kernel_size, padding=kernel_size // 2
        )
        self.nonlin = nonlin

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input with `channels` channels along dim 1 (presumably
            (B, C, ...) spatial layout — TODO confirm).

        Returns
        -------
        torch.Tensor
            `nonlin(a) * sigmoid(b)`, where `a` and `b` are the two channel
            halves of `conv(x)`; same shape as `x`.
        """
        features, gate = self.conv(x).chunk(2, dim=1)
        # TODO remove this? (activation applied to the non-gate half)
        return self.nonlin(features) * gate.sigmoid()
237
+
238
+
239
class ResBlockWithResampling(nn.Module):
    """
    Residual block with resampling.

    Residual block that takes care of resampling (i.e. downsampling or upsampling)
    steps by the factors given in `conv_strides` (typically 2 along each resampled axis).
    It is structured as follows:
        1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or
        a 1x1 convolutional layer that maps the number of channels of the input to `inner_channels`
        (omitted when no resampling is done and `c_in == inner_channels`).
        2. `ResidualBlock`
        3. `post_conv`: a 1x1 convolutional layer that maps the number of channels to `c_out`
        (omitted when `inner_channels == c_out`).

    Some implementation notes:
        - Resampling is performed through a strided convolution layer at the beginning of the block.
        - The strided convolution block has fixed kernel size of 3 and 1 layer of padding with zeros.
        - When upsampling, the transposed convolution uses `output_padding = stride - 1`, so each
          spatial size is multiplied exactly by its stride.
        - The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers.
        - The number of internal channels is by default the same as the number of output channels, but
        min_inner_channels can override the behaviour.
    """

    def __init__(
        self,
        mode: Literal["top-down", "bottom-up"],
        c_in: int,
        c_out: int,
        conv_strides: tuple[int, ...],
        min_inner_channels: Union[int, None] = None,
        nonlin: Callable = nn.LeakyReLU(),
        resample: bool = False,
        res_block_kernel: Optional[Union[int, Iterable[int]]] = None,
        groups: int = 1,
        batchnorm: bool = True,
        res_block_type: Union[str, None] = None,
        dropout: Union[float, None] = None,
        gated: Union[bool, None] = None,
        conv2d_bias: bool = True,
        # lowres_input: bool = False,
    ):
        """
        Constructor.

        Parameters
        ----------
        mode: Literal["top-down", "bottom-up"]
            The type of resampling performed in the initial strided convolution of the block.
            If "bottom-up", downsampling by `conv_strides` is done.
            If "top-down", upsampling by `conv_strides` is done.
        c_in: int
            The number of input channels.
        c_out: int
            The number of output channels.
        conv_strides: tuple[int, ...]
            Strides of the resampling convolution. Its length (2 or 3) also selects
            2D or 3D layers for the whole block.
        min_inner_channels: int, optional
            Lower bound on the number of channels used in the inner layers of this
            module. Default is `None`, meaning that the number of inner channels is
            set to `c_out`.
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is `nn.LeakyReLU()`.
        resample: bool, optional
            Whether to perform resampling in the first convolutional layer.
            If `False`, the first convolutional layer just maps the input to a tensor with
            `inner_channels` channels through 1x1 convolution. Default is `False`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size(s) used in the two convolutions of the residual block.
            Default is `None`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to use gated layer. Default is `None`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.
        """
        super().__init__()
        assert mode in ["top-down", "bottom-up"]

        conv_dims = len(conv_strides)
        conv_layer: ConvType = getattr(nn, f"Conv{conv_dims}d")
        transp_conv_layer: ConvType = getattr(nn, f"ConvTranspose{conv_dims}d")

        if min_inner_channels is None:
            min_inner_channels = 0
        # inner_channels is the number of channels used in the inner layers
        # of ResBlockWithResampling
        inner_channels = max(c_out, min_inner_channels)

        # Define first conv layer to change num channels and/or up/downsample
        if resample:
            if mode == "bottom-up":  # downsample
                self.pre_conv = conv_layer(
                    in_channels=c_in,
                    out_channels=inner_channels,
                    kernel_size=3,
                    padding=1,
                    stride=conv_strides,
                    groups=groups,
                    bias=conv2d_bias,
                )
            elif mode == "top-down":  # upsample
                self.pre_conv = transp_conv_layer(
                    in_channels=c_in,
                    kernel_size=3,
                    out_channels=inner_channels,
                    padding=1,  # TODO maybe don't hardcode this?
                    stride=conv_strides,
                    groups=groups,
                    # Derived from the strides instead of hard-coded, so any
                    # stride combination upsamples by exactly `stride`.
                    output_padding=self._upsample_output_padding(conv_strides),
                    bias=conv2d_bias,
                )
        elif c_in != inner_channels:
            self.pre_conv = conv_layer(
                c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
            )
        else:
            self.pre_conv = None

        # Residual block
        self.res = ResidualBlock(
            channels=inner_channels,
            conv_strides=conv_strides,
            nonlin=nonlin,
            kernel=res_block_kernel,
            groups=groups,
            batchnorm=batchnorm,
            dropout=dropout,
            gated=gated,
            block_type=res_block_type,
            conv2d_bias=conv2d_bias,
        )

        # Define last conv layer to get correct num output channels
        if inner_channels != c_out:
            self.post_conv = conv_layer(
                inner_channels, c_out, 1, groups=groups, bias=conv2d_bias
            )
        else:
            self.post_conv = None

    @staticmethod
    def _upsample_output_padding(conv_strides: Iterable[int]) -> tuple[int, ...]:
        """Return the `output_padding` making the upsampling layer exact.

        With kernel size 3 and padding 1, a transposed convolution of stride `s`
        maps a length `n` to `(n - 1) * s + 1 + output_padding`; an output
        padding of `s - 1` therefore gives exactly `n * s` (and is always
        smaller than the stride, as PyTorch requires). For strides `(2, 2)` and
        `(1, 2, 2)` this reproduces the values `1` and `(0, 1, 1)`.
        """
        return tuple(int(s) - 1 for s in conv_strides)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input with `c_in` channels along dim 1 (spatial layout presumably
            (B, C, ...) — TODO confirm).

        Returns
        -------
        torch.Tensor
            Output with `c_out` channels; spatial size resampled by
            `conv_strides` when `resample=True`, unchanged otherwise.
        """
        if self.pre_conv is not None:
            x = self.pre_conv(x)

        x = self.res(x)

        if self.post_conv is not None:
            x = self.post_conv(x)
        return x
403
+
404
+
405
class TopDownDeterministicResBlock(ResBlockWithResampling):
    """Residual block for the top-down deterministic layers.

    Thin wrapper around `ResBlockWithResampling` with `mode="top-down"`. The
    `upsample` flag is forwarded as `resample` and overrides any `resample`
    keyword supplied by the caller.
    """

    def __init__(self, *args, upsample: bool = False, **kwargs):
        forwarded_kwargs = {**kwargs, "resample": upsample}
        super().__init__("top-down", *args, **forwarded_kwargs)
411
+
412
+
413
class BottomUpDeterministicResBlock(ResBlockWithResampling):
    """Residual block for the bottom-up deterministic layers.

    Thin wrapper around `ResBlockWithResampling` with `mode="bottom-up"`. The
    `downsample` flag is forwarded as `resample` and overrides any `resample`
    keyword supplied by the caller.
    """

    def __init__(self, *args, downsample: bool = False, **kwargs):
        forwarded_kwargs = {**kwargs, "resample": downsample}
        super().__init__("bottom-up", *args, **forwarded_kwargs)
419
+
420
+
421
+ class BottomUpLayer(nn.Module):
422
+ """
423
+ Bottom-up deterministic layer.
424
+
425
+ It consists of one or a stack of `BottomUpDeterministicResBlock`'s.
426
+ The outputs are the so-called `bu_values` that are later used in the Decoder to update the
427
+ generative distributions.
428
+
429
+ NOTE: When Lateral Contextualization is Enabled (i.e., `enable_multiscale=True`),
430
+ the low-res lateral input is first fed through a BottomUpDeterministicBlock (BUDB)
431
+ (without downsampling), and then merged to the latent tensor produced by the primary flow
432
+ of the `BottomUpLayer` through the `MergeLowRes` layer. It is meaningful to remark that
433
+ the BUDB that takes care of encoding the low-res input can be either shared with the
434
+ primary flow (and in that case it is the "same_size" BUDB (or stack of BUDBs) -> see `self.net`),
435
+ or can be a deep-copy of the primary flow's BUDB.
436
+ This behaviour is controlled by `lowres_separate_branch` parameter.
437
+ """
438
+
439
+ def __init__(
440
+ self,
441
+ n_res_blocks: int,
442
+ n_filters: int,
443
+ conv_strides: tuple[int] = (2, 2),
444
+ downsampling_steps: int = 0,
445
+ nonlin: Optional[Callable] = None,
446
+ batchnorm: bool = True,
447
+ dropout: Optional[float] = None,
448
+ res_block_type: Optional[str] = None,
449
+ res_block_kernel: Optional[int] = None,
450
+ gated: Optional[bool] = None,
451
+ enable_multiscale: bool = False,
452
+ multiscale_lowres_size_factor: Optional[int] = None,
453
+ lowres_separate_branch: bool = False,
454
+ multiscale_retain_spatial_dims: bool = False,
455
+ decoder_retain_spatial_dims: bool = False,
456
+ output_expected_shape: Optional[Iterable[int]] = None,
457
+ ):
458
+ """
459
+ Constructor.
460
+
461
+ Parameters
462
+ ----------
463
+ n_res_blocks: int
464
+ Number of `BottomUpDeterministicResBlock` modules stacked in this layer.
465
+ n_filters: int
466
+ Number of channels present through out the layers of this block.
467
+ downsampling_steps: int, optional
468
+ Number of downsampling steps that has to be done in this layer (typically 1).
469
+ Default is 0.
470
+ nonlin: Callable, optional
471
+ The non-linearity function used in the block. Default is `None`.
472
+ batchnorm: bool, optional
473
+ Whether to use batchnorm layers. Default is `True`.
474
+ dropout: float, optional
475
+ The dropout probability in dropout layers. If `None` dropout is not used.
476
+ Default is `None`.
477
+ res_block_type: str, optional
478
+ A string specifying the structure of residual block.
479
+ Check `ResidualBlock` doscstring for more information.
480
+ Default is `None`.
481
+ res_block_kernel: Union[int, Iterable[int]], optional
482
+ The kernel size used in the convolutions of the residual block.
483
+ It can be either a single integer or a pair of integers defining the squared kernel.
484
+ Default is `None`.
485
+ gated: bool, optional
486
+ Whether to use gated layer. Default is `None`.
487
+ enable_multiscale: bool, optional
488
+ Whether to enable multiscale (Lateral Contextualization) or not. Default is `False`.
489
+ multiscale_lowres_size_factor: int, optional
490
+ A factor the expresses the relative size of the primary flow tensor with respect to the
491
+ lower-resolution lateral input tensor. Default in `None`.
492
+ lowres_separate_branch: bool, optional
493
+ Whether the residual block(s) encoding the low-res input should be shared (`False`) or
494
+ not (`True`) with the primary flow "same-size" residual block(s). Default is `False`.
495
+ multiscale_retain_spatial_dims: bool, optional
496
+ Whether to pad the latent tensor resulting from the bottom-up layer's primary flow
497
+ to match the size of the low-res input. Default is `False`.
498
+ decoder_retain_spatial_dims: bool, optional
499
+ Whether in the corresponding top-down layer the shape of tensor is retained between
500
+ input and output. Default is `False`.
501
+ output_expected_shape: Iterable[int], optional
502
+ The expected shape of the layer output (only used if `enable_multiscale == True`).
503
+ Default is `None`.
504
+ """
505
+ super().__init__()
506
+
507
+ # Define attributes for Lateral Contextualization
508
+ self.enable_multiscale = enable_multiscale
509
+ self.lowres_separate_branch = lowres_separate_branch
510
+ self.multiscale_retain_spatial_dims = multiscale_retain_spatial_dims
511
+ self.multiscale_lowres_size_factor = multiscale_lowres_size_factor
512
+ self.decoder_retain_spatial_dims = decoder_retain_spatial_dims
513
+ self.output_expected_shape = output_expected_shape
514
+ assert self.output_expected_shape is None or self.enable_multiscale is True
515
+
516
+ bu_blocks_downsized = []
517
+ bu_blocks_samesize = []
518
+ for _ in range(n_res_blocks):
519
+ do_resample = False
520
+ if downsampling_steps > 0:
521
+ do_resample = True
522
+ downsampling_steps -= 1
523
+ block = BottomUpDeterministicResBlock(
524
+ conv_strides=conv_strides,
525
+ c_in=n_filters,
526
+ c_out=n_filters,
527
+ nonlin=nonlin,
528
+ downsample=do_resample,
529
+ batchnorm=batchnorm,
530
+ dropout=dropout,
531
+ res_block_type=res_block_type,
532
+ res_block_kernel=res_block_kernel,
533
+ gated=gated,
534
+ )
535
+ if do_resample:
536
+ bu_blocks_downsized.append(block)
537
+ else:
538
+ bu_blocks_samesize.append(block)
539
+
540
+ self.net_downsized = nn.Sequential(*bu_blocks_downsized)
541
+ self.net = nn.Sequential(*bu_blocks_samesize)
542
+
543
+ # Using the same net for the low resolution (and larger sized image)
544
+ self.lowres_net = self.lowres_merge = None
545
+ if self.enable_multiscale:
546
+ self._init_multiscale(
547
+ n_filters=n_filters,
548
+ conv_strides=conv_strides,
549
+ nonlin=nonlin,
550
+ batchnorm=batchnorm,
551
+ dropout=dropout,
552
+ res_block_type=res_block_type,
553
+ )
554
+
555
+ # msg = f'[{self.__class__.__name__}] McEnabled:{int(enable_multiscale)} '
556
+ # if enable_multiscale:
557
+ # msg += f'McParallelBeam:{int(multiscale_retain_spatial_dims)} McFactor{multiscale_lowres_size_factor}'
558
+ # print(msg)
559
+
560
+ def _init_multiscale(
561
+ self,
562
+ nonlin: Callable = None,
563
+ n_filters: int = None,
564
+ conv_strides: tuple[int] = (2, 2),
565
+ batchnorm: bool = None,
566
+ dropout: float = None,
567
+ res_block_type: str = None,
568
+ ) -> None:
569
+ """
570
+ Bottom-up layer's method that initializes the LC modules.
571
+
572
+ Defines the modules responsible of merging compressed lateral inputs to the
573
+ outputs of the primary flow at different hierarchical levels in the
574
+ multiresolution approach (LC). Specifically, the method initializes `lowres_net`
575
+ , which is a stack of `BottomUpDeterministicBlock`'s (w/out downsampling) that
576
+ takes care of additionally processing the low-res input, and `lowres_merge`,
577
+ which is the module responsible of merging the compressed lateral input to the
578
+ main flow.
579
+
580
+ NOTE: The merge modality is set by default to "residual", meaning that the
581
+ merge layer performs concatenation on dim=1, followed by 1x1 convolution and
582
+ a Residual Gated block.
583
+
584
+ Parameters
585
+ ----------
586
+ nonlin: Callable, optional
587
+ The non-linearity function used in the block. Default is `None`.
588
+ n_filters: int
589
+ Number of channels present through out the layers of this block.
590
+ batchnorm: bool, optional
591
+ Whether to use batchnorm layers. Default is `True`.
592
+ dropout: float, optional
593
+ The dropout probability in dropout layers. If `None` dropout is not used.
594
+ Default is `None`.
595
+ res_block_type: str, optional
596
+ A string specifying the structure of residual block.
597
+ Check `ResidualBlock` doscstring for more information.
598
+ Default is `None`.
599
+ """
600
+ self.lowres_net = self.net
601
+ if self.lowres_separate_branch:
602
+ self.lowres_net = deepcopy(self.net)
603
+
604
+ self.lowres_merge = MergeLowRes(
605
+ channels=n_filters,
606
+ conv_strides=conv_strides,
607
+ merge_type="residual",
608
+ nonlin=nonlin,
609
+ batchnorm=batchnorm,
610
+ dropout=dropout,
611
+ res_block_type=res_block_type,
612
+ multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
613
+ multiscale_lowres_size_factor=self.multiscale_lowres_size_factor,
614
+ )
615
+
616
+ def forward(
617
+ self, x: torch.Tensor, lowres_x: Union[torch.Tensor, None] = None
618
+ ) -> tuple[torch.Tensor, torch.Tensor]:
619
+ """Forward pass.
620
+
621
+ Parameters
622
+ ----------
623
+ x: torch.Tensor
624
+ The input of the `BottomUpLayer`, i.e., the input image or the output of the
625
+ previous layer.
626
+ lowres_x: torch.Tensor, optional
627
+ The low-res input used for Lateral Contextualization (LC). Default is `None`.
628
+
629
+ NOTE: first returned tensor is used as input for the next BU layer, while the second
630
+ tensor is the bu_value passed to the top-down layer.
631
+ """
632
+ # The input is fed through the residual downsampling block(s)
633
+ primary_flow = self.net_downsized(x)
634
+ # The downsampling output is fed through additional residual block(s)
635
+ primary_flow = self.net(primary_flow)
636
+
637
+ # If LC is not used, simply return output of primary-flow
638
+ if self.enable_multiscale is False:
639
+ assert lowres_x is None
640
+ return primary_flow, primary_flow
641
+
642
+ if lowres_x is not None:
643
+ # First encode the low-res lateral input
644
+ lowres_flow = self.lowres_net(lowres_x)
645
+ # Then pass the result through the MergeLowRes layer
646
+ merged = self.lowres_merge(primary_flow, lowres_flow)
647
+ else:
648
+ merged = primary_flow
649
+
650
+ # NOTE: Explanation of possible cases for the conditionals:
651
+ # - if both are `True` -> `merged` has the same spatial dims as the input (`x`) since
652
+ # spatial dims are retained by padding `primary_flow` in `MergeLowRes`. This is
653
+ # OK for the corresp TopDown layer, as it also retains spatial dims.
654
+ # - if both are `False` -> `merged`'s spatial dims are equal to `self.net_downsized(x)`,
655
+ # since no padding is done in `MergeLowRes` and, instead, the lowres input is cropped.
656
+ # This is OK for the corresp TopDown layer, as it also halves the spatial dims.
657
+ # - if 1st is `False` and 2nd is `True` -> not a concern, it cannot happen
658
+ # (see lvae.py, line 111, intialization of `multiscale_decoder_retain_spatial_dims`).
659
+ if (
660
+ self.multiscale_retain_spatial_dims is False
661
+ or self.decoder_retain_spatial_dims is True
662
+ ):
663
+ return merged, merged
664
+
665
+ # NOTE: if we reach here, it means that `multiscale_retain_spatial_dims` is `True`,
666
+ # but `decoder_retain_spatial_dims` is `False`, meaning that merging LC preserves
667
+ # the spatial dimensions, but at the same time we don't want to retain the spatial
668
+ # dims in the corresponding top-down layer. Therefore, we need to crop the tensor.
669
+ if self.output_expected_shape is not None:
670
+ expected_shape = self.output_expected_shape
671
+ else:
672
+ fac = self.multiscale_lowres_size_factor
673
+ expected_shape = (merged.shape[-2] // fac, merged.shape[-1] // fac)
674
+ assert merged.shape[-2:] != expected_shape
675
+
676
+ # Crop the resulting tensor so that it matches with the Decoder
677
+ value_to_use_in_topdown = crop_img_tensor(merged, expected_shape)
678
+ return merged, value_to_use_in_topdown
679
+
680
+
681
+ class MergeLayer(nn.Module):
682
+ """
683
+ Layer class that merges two or more input tensors.
684
+
685
+ Merges two or more (B, C, [Z], Y, X) input tensors by concatenating
686
+ them along dim=1 and passes the result through:
687
+ a) a convolutional 1x1 layer (`merge_type == "linear"`), or
688
+ b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
689
+ c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
690
+ """
691
+
692
+ def __init__(
693
+ self,
694
+ merge_type: Literal["linear", "residual", "residual_ungated"],
695
+ channels: Union[int, Iterable[int]],
696
+ conv_strides: tuple[int] = (2, 2),
697
+ nonlin: Callable = nn.LeakyReLU(),
698
+ batchnorm: bool = True,
699
+ dropout: Optional[float] = None,
700
+ res_block_type: Optional[str] = None,
701
+ res_block_kernel: Optional[int] = None,
702
+ conv2d_bias: Optional[bool] = True,
703
+ ):
704
+ """
705
+ Constructor.
706
+
707
+ Parameters
708
+ ----------
709
+ merge_type: Literal["linear", "residual", "residual_ungated"]
710
+ The type of merge done in the layer. It can be chosen between "linear",
711
+ "residual", and "residual_ungated". Check the class docstring for more
712
+ information about the behaviour of different merge modalities.
713
+ channels: Union[int, Iterable[int]]
714
+ The number of channels used in the convolutional blocks of this layer.
715
+ If it is an `int`:
716
+ - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
717
+ - (Optional) ResBlock: in_channels=channels, out_channels=channels
718
+ If it is an Iterable (must have `len(channels)==3`):
719
+ - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]),
720
+ out_channels=channels[-1]
721
+ - (Optional) ResBlock: in_channels=channels[-1],
722
+ out_channels=channels[-1]
723
+ conv_strides: tuple, optional
724
+ The strides used in the convolutions. Default is `(2, 2)`.
725
+ nonlin: Callable, optional
726
+ The non-linearity function used in the block. Default is `nn.LeakyReLU`.
727
+ batchnorm: bool, optional
728
+ Whether to use batchnorm layers. Default is `True`.
729
+ dropout: float, optional
730
+ The dropout probability in dropout layers. If `None` dropout is not used.
731
+ Default is `None`.
732
+ res_block_type: str, optional
733
+ A string specifying the structure of residual block.
734
+ Check `ResidualBlock` doscstring for more information.
735
+ Default is `None`.
736
+ res_block_kernel: Union[int, Iterable[int]], optional
737
+ The kernel size used in the convolutions of the residual block.
738
+ It can be either a single integer or a pair of integers defining the squared
739
+ kernel.
740
+ Default is `None`.
741
+ conv2d_bias: bool, optional
742
+ Whether to use bias term in convolutions. Default is `True`.
743
+ """
744
+ super().__init__()
745
+ try:
746
+ iter(channels)
747
+ except TypeError: # it is not iterable
748
+ channels = [channels] * 3
749
+ else: # it is iterable
750
+ if len(channels) == 1:
751
+ channels = [channels[0]] * 3
752
+
753
+ self.conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
754
+
755
+ if merge_type == "linear":
756
+ self.layer = self.conv_layer(
757
+ sum(channels[:-1]), channels[-1], 1, bias=conv2d_bias
758
+ )
759
+ elif merge_type == "residual":
760
+ self.layer = nn.Sequential(
761
+ self.conv_layer(
762
+ sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
763
+ ),
764
+ ResidualGatedBlock(
765
+ conv_strides=conv_strides,
766
+ channels=channels[-1],
767
+ nonlin=nonlin,
768
+ batchnorm=batchnorm,
769
+ dropout=dropout,
770
+ block_type=res_block_type,
771
+ kernel=res_block_kernel,
772
+ conv2d_bias=conv2d_bias,
773
+ ),
774
+ )
775
+ elif merge_type == "residual_ungated":
776
+ self.layer = nn.Sequential(
777
+ self.conv_layer(
778
+ sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias
779
+ ),
780
+ ResidualBlock(
781
+ conv_strides=conv_strides,
782
+ channels=channels[-1],
783
+ nonlin=nonlin,
784
+ batchnorm=batchnorm,
785
+ dropout=dropout,
786
+ block_type=res_block_type,
787
+ kernel=res_block_kernel,
788
+ conv2d_bias=conv2d_bias,
789
+ ),
790
+ )
791
+
792
+ def forward(self, *args) -> torch.Tensor:
793
+
794
+ # Concatenate the input tensors along dim=1
795
+ x = torch.cat(args, dim=1)
796
+
797
+ # Pass the concatenated tensor through the conv layer
798
+ x = self.layer(x)
799
+
800
+ return x
801
+
802
+
803
+ class MergeLowRes(MergeLayer):
804
+ """
805
+ Child class of `MergeLayer`.
806
+
807
+ Specifically designed to merge the low-resolution patches
808
+ that are used in Lateral Contextualization approach.
809
+ """
810
+
811
+ def __init__(self, *args, **kwargs):
812
+ self.retain_spatial_dims = kwargs.pop("multiscale_retain_spatial_dims")
813
+ self.multiscale_lowres_size_factor = kwargs.pop("multiscale_lowres_size_factor")
814
+ super().__init__(*args, **kwargs)
815
+
816
+ def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
817
+ """Forward pass.
818
+
819
+ Parameters
820
+ ----------
821
+ latent: torch.Tensor
822
+ The output latent tensor from previous layer in the LVAE hierarchy.
823
+ lowres: torch.Tensor
824
+ The low-res patch image to be merged to increase the context.
825
+ """
826
+ # TODO: treat (X, Y) and Z differently (e.g., line 762)
827
+ if self.retain_spatial_dims:
828
+ # Pad latent tensor to match lowres tensor's shape
829
+ # Output.shape == Lowres.shape (== Input.shape),
830
+ # where Input is the input to the BU layer
831
+ latent = pad_img_tensor(latent, lowres.shape[2:])
832
+ else:
833
+ # Crop lowres tensor to match latent tensor's shape
834
+ lz, ly, lx = lowres.shape[2:]
835
+ z = lz // self.multiscale_lowres_size_factor
836
+ y = ly // self.multiscale_lowres_size_factor
837
+ x = lx // self.multiscale_lowres_size_factor
838
+ z_pad = (lz - z) // 2
839
+ y_pad = (ly - y) // 2
840
+ x_pad = (lx - x) // 2
841
+ lowres = lowres[:, :, z_pad:-z_pad, y_pad:-y_pad, x_pad:-x_pad]
842
+
843
+ return super().forward(latent, lowres)
844
+
845
+
846
+ class SkipConnectionMerger(MergeLayer):
847
+ """Specialized `MergeLayer` module, handles skip connections in the model."""
848
+
849
+ def __init__(
850
+ self,
851
+ nonlin: Callable,
852
+ channels: Union[int, Iterable[int]],
853
+ batchnorm: bool,
854
+ dropout: float,
855
+ res_block_type: str,
856
+ conv_strides: tuple[int] = (2, 2),
857
+ merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
858
+ conv2d_bias: bool = True,
859
+ res_block_kernel: Optional[int] = None,
860
+ ):
861
+ """
862
+ Constructor.
863
+
864
+ nonlin: Callable, optional
865
+ The non-linearity function used in the block. Default is `nn.LeakyReLU`.
866
+ channels: Union[int, Iterable[int]]
867
+ The number of channels used in the convolutional blocks of this layer.
868
+ If it is an `int`:
869
+ - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels
870
+ - (Optional) ResBlock: in_channels=channels, out_channels=channels
871
+ If it is an Iterable (must have `len(channels)==3`):
872
+ - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1]
873
+ - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
874
+ batchnorm: bool
875
+ Whether to use batchnorm layers.
876
+ dropout: float
877
+ The dropout probability in dropout layers. If `None` dropout is not used.
878
+ res_block_type: str
879
+ A string specifying the structure of residual block.
880
+ Check `ResidualBlock` doscstring for more information.
881
+ conv_strides: tuple, optional
882
+ The strides used in the convolutions. Default is `(2, 2)`.
883
+ merge_type: Literal["linear", "residual", "residual_ungated"]
884
+ The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated".
885
+ Check the class docstring for more information about the behaviour of different merge modalities.
886
+ conv2d_bias: bool, optional
887
+ Whether to use bias term in convolutions. Default is `True`.
888
+ res_block_kernel: Union[int, Iterable[int]], optional
889
+ The kernel size used in the convolutions of the residual block.
890
+ It can be either a single integer or a pair of integers defining the squared kernel.
891
+ Default is `None`.
892
+ """
893
+ super().__init__(
894
+ conv_strides=conv_strides,
895
+ channels=channels,
896
+ nonlin=nonlin,
897
+ merge_type=merge_type,
898
+ batchnorm=batchnorm,
899
+ dropout=dropout,
900
+ res_block_type=res_block_type,
901
+ res_block_kernel=res_block_kernel,
902
+ conv2d_bias=conv2d_bias,
903
+ )
904
+
905
+
906
+ class TopDownLayer(nn.Module):
907
+ """Top-down inference layer.
908
+
909
+ It includes:
910
+ - Stochastic sampling,
911
+ - Computation of KL divergence,
912
+ - A small deterministic ResNet that performs upsampling.
913
+
914
+ NOTE 1:
915
+ The algorithm for generative inference approximately works as follows:
916
+ - p_params = output of top-down layer above
917
+ - bu = inferred bottom-up value at this layer
918
+ - q_params = merge(bu, p_params)
919
+ - z = stochastic_layer(q_params)
920
+ - (optional) get and merge skip connection from prev top-down layer
921
+ - top-down deterministic ResNet
922
+
923
+ NOTE 2:
924
+ The Top-Down layer can work in two modes: inference and prediction/generative.
925
+ Depending on the particular mode, it follows distinct behaviours:
926
+ - In inference mode, parameters of q(z_i|z_i+1) are obtained from the inference path,
927
+ by merging outcomes of bottom-up and top-down passes. The exception is the top layer,
928
+ in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer.
929
+ - On the contrary in predicition/generative mode, parameters of q(z_i|z_i+1) can be obtained
930
+ once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is
931
+ possible to directly sample from the prior p(z_i|z_i+1) (UNCONDITIONAL GENERATION).
932
+
933
+ NOTE 3:
934
+ When doing unconditional generation, bu_value is not available. Hence the
935
+ merge layer is not used, and z is sampled directly from p_params.
936
+
937
+ NOTE 4:
938
+ If this is the top layer, at inference time, the uppermost bottom-up value
939
+ is used directly as q_params, and p_params are defined in this layer
940
+ (while they are usually taken from the previous layer), and can be learned.
941
+ """
942
+
943
+ def __init__(
944
+ self,
945
+ z_dim: int,
946
+ n_res_blocks: int,
947
+ n_filters: int,
948
+ conv_strides: tuple[int],
949
+ is_top_layer: bool = False,
950
+ upsampling_steps: Union[int, None] = None,
951
+ nonlin: Union[Callable, None] = None,
952
+ merge_type: Union[
953
+ Literal["linear", "residual", "residual_ungated"], None
954
+ ] = None,
955
+ batchnorm: bool = True,
956
+ dropout: Union[float, None] = None,
957
+ stochastic_skip: bool = False,
958
+ res_block_type: Union[str, None] = None,
959
+ res_block_kernel: Union[int, None] = None,
960
+ groups: int = 1,
961
+ gated: Union[bool, None] = None,
962
+ learn_top_prior: bool = False,
963
+ top_prior_param_shape: Union[Iterable[int], None] = None,
964
+ analytical_kl: bool = False,
965
+ retain_spatial_dims: bool = False,
966
+ vanilla_latent_hw: Union[Iterable[int], None] = None,
967
+ input_image_shape: Union[tuple[int, int], None] = None,
968
+ normalize_latent_factor: float = 1.0,
969
+ conv2d_bias: bool = True,
970
+ stochastic_use_naive_exponential: bool = False,
971
+ ):
972
+ """
973
+ Constructor.
974
+
975
+ Parameters
976
+ ----------
977
+ z_dim: int
978
+ The size of the latent space.
979
+ n_res_blocks: int
980
+ The number of TopDownDeterministicResBlock blocks
981
+ n_filters: int
982
+ The number of channels present through out the layers of this block.
983
+ conv_strides: tuple, optional
984
+ The strides used in the convolutions. Default is `(2, 2)`.
985
+ is_top_layer: bool, optional
986
+ Whether the current layer is at the top of the Decoder hierarchy. Default is `False`.
987
+ upsampling_steps: int, optional
988
+ The number of upsampling steps that has to be done in this layer (typically 1).
989
+ Default is `None`.
990
+ nonlin: Callable, optional
991
+ The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
992
+ merge_type: Literal["linear", "residual", "residual_ungated"], optional
993
+ The type of merge done in the layer. It can be chosen between "linear", "residual",
994
+ and "residual_ungated". Check the `MergeLayer` class docstring for more information
995
+ about the behaviour of different merging modalities. Default is `None`.
996
+ batchnorm: bool, optional
997
+ Whether to use batchnorm layers. Default is `True`.
998
+ dropout: float, optional
999
+ The dropout probability in dropout layers. If `None` dropout is not used.
1000
+ Default is `None`.
1001
+ stochastic_skip: bool, optional
1002
+ Whether to use skip connections between previous top-down layer's output and this layer's stochastic output.
1003
+ Stochastic skip connection allows the previous layer's output has a way to directly reach this hierarchical
1004
+ level, hence facilitating the gradient flow during backpropagation. Default is `False`.
1005
+ res_block_type: str, optional
1006
+ A string specifying the structure of residual block.
1007
+ Check `ResidualBlock` documentation for more information.
1008
+ Default is `None`.
1009
+ res_block_kernel: Union[int, Iterable[int]], optional
1010
+ The kernel size used in the convolutions of the residual block.
1011
+ It can be either a single integer or a pair of integers defining the squared kernel.
1012
+ Default is `None`.
1013
+ groups: int, optional
1014
+ The number of groups to consider in the convolutions. Default is 1.
1015
+ gated: bool, optional
1016
+ Whether to use gated layer in `ResidualBlock`. Default is `None`.
1017
+ learn_top_prior:
1018
+ Whether to set the top prior as learnable.
1019
+ If this is set to `False`, in the top-most layer the prior will be N(0,1).
1020
+ Otherwise, we will still have a normal distribution whose parameters will be learnt.
1021
+ Default is `False`.
1022
+ top_prior_param_shape: Iterable[int], optional
1023
+ The size of the tensor which expresses the mean and the variance
1024
+ of the prior for the top most layer. Default is `None`.
1025
+ analytical_kl: bool, optional
1026
+ If True, KL divergence is calculated according to the analytical formula.
1027
+ Otherwise, an MC approximation using sampled latents is calculated.
1028
+ Default is `False`.
1029
+ retain_spatial_dims: bool, optional
1030
+ If `True`, the size of Encoder's latent space is kept to `input_image_shape` within the topdown layer.
1031
+ This implies that the oput spatial size equals the input spatial size.
1032
+ To achieve this, we centercrop the intermediate representation.
1033
+ Default is `False`.
1034
+ vanilla_latent_hw: Iterable[int], optional
1035
+ The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL).
1036
+ Default is `None`.
1037
+ input_image_shape: Tuple[int, int], optionalut
1038
+ The shape of the input image tensor.
1039
+ When `retain_spatial_dims` is set to `True`, this is used to ensure that the shape of this layer
1040
+ output has the same shape as the input. Default is `None`.
1041
+ normalize_latent_factor: float, optional
1042
+ A factor used to normalize the latent tensors `q_params`.
1043
+ Specifically, normalization is done by dividing the latent tensor by this factor.
1044
+ Default is 1.0.
1045
+ conv2d_bias: bool, optional
1046
+ Whether to use bias term is the convolutional blocks of this layer.
1047
+ Default is `True`.
1048
+ stochastic_use_naive_exponential: bool, optional
1049
+ If `False`, in the NormalStochasticBlock2d exponentials are computed according
1050
+ to the alternative definition provided by `StableExponential` class.
1051
+ This should improve numerical stability in the training process.
1052
+ Default is `False`.
1053
+ """
1054
+ super().__init__()
1055
+
1056
+ self.is_top_layer = is_top_layer
1057
+ self.z_dim = z_dim
1058
+ self.stochastic_skip = stochastic_skip
1059
+ self.learn_top_prior = learn_top_prior
1060
+ self.analytical_kl = analytical_kl
1061
+ self.retain_spatial_dims = retain_spatial_dims
1062
+ self.input_image_shape = (
1063
+ input_image_shape if len(conv_strides) == 3 else input_image_shape[1:]
1064
+ )
1065
+ self.latent_shape = self.input_image_shape if self.retain_spatial_dims else None
1066
+ self.normalize_latent_factor = normalize_latent_factor
1067
+ self._vanilla_latent_hw = vanilla_latent_hw # TODO: check this, it is not used
1068
+
1069
+ # Define top layer prior parameters, possibly learnable
1070
+ if is_top_layer:
1071
+ self.top_prior_params = nn.Parameter(
1072
+ torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
1073
+ )
1074
+
1075
+ # Upsampling steps left to do in this layer
1076
+ ups_left = upsampling_steps
1077
+
1078
+ # Define deterministic top-down block, which is a sequence of deterministic
1079
+ # residual blocks with (optional) upsampling.
1080
+ block_list = []
1081
+ for _ in range(n_res_blocks):
1082
+ do_resample = False
1083
+ if ups_left > 0:
1084
+ do_resample = True
1085
+ ups_left -= 1
1086
+ block_list.append(
1087
+ TopDownDeterministicResBlock(
1088
+ c_in=n_filters,
1089
+ c_out=n_filters,
1090
+ conv_strides=conv_strides,
1091
+ nonlin=nonlin,
1092
+ upsample=do_resample,
1093
+ batchnorm=batchnorm,
1094
+ dropout=dropout,
1095
+ res_block_type=res_block_type,
1096
+ res_block_kernel=res_block_kernel,
1097
+ gated=gated,
1098
+ conv2d_bias=conv2d_bias,
1099
+ groups=groups,
1100
+ )
1101
+ )
1102
+ self.deterministic_block = nn.Sequential(*block_list)
1103
+
1104
+ # Define stochastic block with convolutions
1105
+
1106
+ self.stochastic = NormalStochasticBlock(
1107
+ c_in=n_filters,
1108
+ c_vars=z_dim,
1109
+ c_out=n_filters,
1110
+ conv_dims=len(conv_strides),
1111
+ transform_p_params=(not is_top_layer),
1112
+ vanilla_latent_hw=vanilla_latent_hw,
1113
+ use_naive_exponential=stochastic_use_naive_exponential,
1114
+ )
1115
+
1116
+ if not is_top_layer:
1117
+ # Merge layer: it combines bottom-up inference and top-down
1118
+ # generative outcomes to give posterior parameters
1119
+ self.merge = MergeLayer(
1120
+ channels=n_filters,
1121
+ conv_strides=conv_strides,
1122
+ merge_type=merge_type,
1123
+ nonlin=nonlin,
1124
+ batchnorm=batchnorm,
1125
+ dropout=dropout,
1126
+ res_block_type=res_block_type,
1127
+ res_block_kernel=res_block_kernel,
1128
+ conv2d_bias=conv2d_bias,
1129
+ )
1130
+
1131
+ # Skip connection that goes around the stochastic top-down layer
1132
+ if stochastic_skip:
1133
+ self.skip_connection_merger = SkipConnectionMerger(
1134
+ channels=n_filters,
1135
+ conv_strides=conv_strides,
1136
+ nonlin=nonlin,
1137
+ batchnorm=batchnorm,
1138
+ dropout=dropout,
1139
+ res_block_type=res_block_type,
1140
+ merge_type=merge_type,
1141
+ conv2d_bias=conv2d_bias,
1142
+ res_block_kernel=res_block_kernel,
1143
+ )
1144
+
1145
+ def sample_from_q(
1146
+ self,
1147
+ input_: torch.Tensor,
1148
+ bu_value: torch.Tensor,
1149
+ var_clip_max: Optional[float] = None,
1150
+ mask: torch.Tensor = None,
1151
+ ) -> torch.Tensor:
1152
+ """
1153
+ Method computes the latent inference distribution q(z_i|z_{i+1}).
1154
+
1155
+ Used for sampling a latent tensor from it.
1156
+
1157
+ Parameters
1158
+ ----------
1159
+ input_: torch.Tensor
1160
+ The input tensor to the layer, which is the output of the top-down layer.
1161
+ bu_value: torch.Tensor
1162
+ The tensor defining the parameters /mu_q and /sigma_q computed during the
1163
+ bottom-up deterministic pass at the correspondent hierarchical layer.
1164
+ var_clip_max: float, optional
1165
+ The maximum value reachable by the log-variance of the latent distribution.
1166
+ Values exceeding this threshold are clipped. Default is `None`.
1167
+ mask: Union[None, torch.Tensor], optional
1168
+ A tensor that is used to mask the sampled latent tensor. Default is `None`.
1169
+ """
1170
+ if self.is_top_layer: # In top layer, we don't merge bu_value with p_params
1171
+ q_params = bu_value
1172
+ else:
1173
+ # NOTE: Here the assumption is that the vampprior is only applied on the top layer.
1174
+ n_img_prior = None
1175
+ p_params = self.get_p_params(input_, n_img_prior)
1176
+ q_params = self.merge(bu_value, p_params)
1177
+
1178
+ sample = self.stochastic.sample_from_q(q_params, var_clip_max)
1179
+
1180
+ if mask:
1181
+ return sample[mask]
1182
+
1183
+ return sample
1184
+
1185
+ def get_p_params(
1186
+ self,
1187
+ input_: torch.Tensor,
1188
+ n_img_prior: int,
1189
+ ) -> torch.Tensor:
1190
+ """Return the parameters of the prior distribution p(z_i|z_{i+1}).
1191
+
1192
+ The parameters depend on the hierarchical level of the layer:
1193
+ - if it is the topmost level, parameters are the ones of the prior.
1194
+ - else, the input from the layer above is the parameters itself.
1195
+
1196
+ Parameters
1197
+ ----------
1198
+ input_: torch.Tensor
1199
+ The input tensor to the layer, which is the output of the top-down layer above.
1200
+ n_img_prior: int
1201
+ The number of images to be generated from the unconditional prior distribution p(z_L).
1202
+ """
1203
+ p_params = None
1204
+
1205
+ # If top layer, define p_params as the ones of the prior p(z_L)
1206
+ if self.is_top_layer:
1207
+ p_params = self.top_prior_params
1208
+
1209
+ # Sample specific number of images by expanding the prior
1210
+ if n_img_prior is not None:
1211
+ p_params = p_params.expand(n_img_prior, -1, -1, -1)
1212
+
1213
+ # Else the input from the layer above is p_params itself
1214
+ else:
1215
+ p_params = input_
1216
+
1217
+ return p_params
1218
+
1219
def forward(
    self,
    input_: Union[torch.Tensor, None] = None,
    skip_connection_input: Union[torch.Tensor, None] = None,
    inference_mode: bool = False,
    bu_value: Union[torch.Tensor, None] = None,
    n_img_prior: Union[int, None] = None,
    forced_latent: Union[torch.Tensor, None] = None,
    force_constant_output: bool = False,
    mode_pred: bool = False,
    use_uncond_mode: bool = False,
    var_clip_max: Union[float, None] = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Forward pass.

    Parameters
    ----------
    input_: torch.Tensor, optional
        The input tensor to the layer, which is the output of the top-down layer.
        Must be `None` in the top layer. Default is `None`.
    skip_connection_input: torch.Tensor, optional
        The tensor brought by the skip connection between the current and the
        previous top-down layer. Must be `None` in the top layer.
        Default is `None`.
    inference_mode: bool, optional
        Whether the layer is in inference mode. See NOTE 2 in class description
        for more info. If `False` (generative mode), q is not used and the
        latent is sampled from p(z_i | z_{i+1}). Default is `False`.
    bu_value: torch.Tensor, optional
        The tensor defining the parameters /mu_q and /sigma_q computed during the
        bottom-up deterministic pass at the correspondent hierarchical layer.
        Default is `None`.
    n_img_prior: int, optional
        The number of images to be generated from the unconditional prior
        distribution p(z_L). Default is `None`.
    forced_latent: torch.Tensor, optional
        A pre-defined latent tensor. If it is not `None`, then it is used as the
        actual latent tensor and, hence, sampling does not happen.
        Default is `None`.
    force_constant_output: bool, optional
        Whether to copy the first sample (and rel. distrib parameters) over the
        whole batch. This is used when doing experiment from the prior - q is
        not used. Default is `False`.
    mode_pred: bool, optional
        Whether the model is in prediction mode. Default is `False`.
    use_uncond_mode: bool, optional
        Whether to use the unconditional distribution p(z) to sample latents in
        prediction mode. Default is `False`.
    var_clip_max: float, optional
        The maximum value reachable by the log-variance of the latent
        distribution. Values exceeding this threshold are clipped.
        Default is `None`.

    Returns
    -------
    tuple[torch.Tensor, dict[str, torch.Tensor]]
        The output of the deterministic block, and a dictionary with the keys
        "z", "kl_samplewise", "kl_samplewise_restricted", "kl_spatial",
        "kl_channelwise", "logprob_q", "qvar_max" (each `None` when the
        stochastic layer did not provide it), plus "q_mu" and "q_lv"
        (`None` when no q parameters were returned).

    Raises
    ------
    ValueError
        If this is the top layer and `input_` or `skip_connection_input`
        is not `None`.
    """
    # Check consistency of arguments
    inputs_none = input_ is None and skip_connection_input is None
    if self.is_top_layer and not inputs_none:
        raise ValueError("In top layer, inputs should be None")

    p_params = self.get_p_params(input_, n_img_prior)

    # Get the parameters for the latent distribution to sample from.
    # TODO (from original): clarify the role of `inference_mode` here.
    if inference_mode:
        if self.is_top_layer:
            q_params = bu_value
            if mode_pred is False:
                assert p_params.shape[2:] == bu_value.shape[2:], (
                    "Spatial dimensions of p_params and bu_value should match. "
                    f"Instead, we got p_params={p_params.shape[2:]} and "
                    f"bu_value={bu_value.shape[2:]}."
                )
        else:
            if use_uncond_mode:
                q_params = p_params
            else:
                assert p_params.shape[2:] == bu_value.shape[2:], (
                    "Spatial dimensions of p_params and bu_value should match. "
                    f"Instead, we got p_params={p_params.shape[2:]} and "
                    f"bu_value={bu_value.shape[2:]}."
                )
                q_params = self.merge(bu_value, p_params)
    else:  # generative mode, q is not used, we sample from p(z_i | z_{i+1})
        q_params = None

    # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
    # depending on the mode (hence, in practice, by checking whether q_params is None).

    # Normalization of latent space parameters for stability.
    # See "Very Deep VAEs Generalize Autoregressive Models".
    # In generative mode q_params is None and there is nothing to normalize;
    # dividing None by the factor would raise a TypeError.
    if self.normalize_latent_factor and q_params is not None:
        q_params = q_params / self.normalize_latent_factor

    # Sample (and process) a latent tensor in the stochastic layer
    x, data_stoch = self.stochastic(
        p_params=p_params,
        q_params=q_params,
        forced_latent=forced_latent,
        force_constant_output=force_constant_output,
        analytical_kl=self.analytical_kl,
        mode_pred=mode_pred,
        use_uncond_mode=use_uncond_mode,
        var_clip_max=var_clip_max,
    )

    # Merge skip connection from previous layer
    if self.stochastic_skip and not self.is_top_layer:
        x = self.skip_connection_merger(x, skip_connection_input)

    if self.retain_spatial_dims:
        # NOTE: we assume that one topdown layer will have exactly one upscaling layer.

        # NOTE: in case, in the Bottom-Up layer, LC retains spatial dimensions,
        # we have the following (see `MergeLowRes`):
        # - the "primary-flow" tensor is padded to match the low-res patch size
        #   (e.g., from 32x32 to 64x64)
        # - padded tensor is then merged with the low-res patch (concatenation
        #   along dim=1 + convolution)
        # Therefore, we need to do the symmetric operation here, that is to
        # crop `x` for the same amount we padded it in the correspondent BU layer.

        # NOTE: cropping is done to retain the shape of the input in the output.
        # Therefore we need it only in the case `x` is the same shape of the input,
        # because that's the only case in which we need to retain the shape.
        # Here, it must be strictly greater than half the input shape, which is
        # the case if and only if `x.shape == self.latent_shape`.
        rescale = (
            np.array((1, 2, 2)) if len(self.latent_shape) == 3 else np.array((2, 2))
        )  # TODO better way?
        new_latent_shape = tuple(np.array(self.latent_shape) // rescale)
        if x.shape[-1] > new_latent_shape[-1]:
            x = crop_img_tensor(x, new_latent_shape)

    # TODO: `retain_spatial_dims` is the same for all the TD layers.
    # How to handle the case in which we do not have LC for all layers?
    # The answer is in `self.latent_shape`, which is equal to `input_image_shape`
    # (e.g., (64, 64)) if `retain_spatial_dims` is `True`, else it is `None`.

    # Last top-down block (sequence of residual blocks w\ upsampling)
    x = self.deterministic_block(x)

    # Save some metrics that will be used in the loss computation
    keys = [
        "z",
        "kl_samplewise",
        "kl_samplewise_restricted",
        "kl_spatial",
        "kl_channelwise",
        "logprob_q",
        "qvar_max",
    ]
    data = {k: data_stoch.get(k, None) for k in keys}
    data["q_mu"] = None
    data["q_lv"] = None
    # Same lookup style as the other keys; a missing entry is treated as "no q".
    stoch_q_params = data_stoch.get("q_params")
    if stoch_q_params is not None:
        q_mu, q_lv = stoch_q_params
        data["q_mu"] = q_mu
        data["q_lv"] = q_lv
    return x, data