careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/models/lvae/lvae.py
@@ -0,0 +1,848 @@
+"""
+Ladder VAE (LVAE) Model.
+
+The current implementation is based on "Interpretable Unsupervised Diversity
+Denoising and Artefact Removal" (Prakash et al.).
+"""
+
+from collections.abc import Iterable
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..activation import get_activation
+from .layers import (
+    BottomUpDeterministicResBlock,
+    BottomUpLayer,
+    GateLayer,
+    TopDownDeterministicResBlock,
+    TopDownLayer,
+)
+from .utils import Interpolate, ModelType, crop_img_tensor
+
+
+class LadderVAE(nn.Module):
+    """
+    Ladder VAE (LVAE) model.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, ...]
+        The shape of the input image, (Z, Y, X) for 3D data or (Y, X) for 2D data.
+    output_channels : int
+        The number of output channels.
+    multiscale_count : int
+        The number of scales for multiscale processing.
+    z_dims : list[int]
+        The dimensions of the latent space for each layer.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_conv_strides : list[int]
+        The strides of the encoder conv layers.
+    decoder_conv_strides : list[int]
+        The strides of the decoder conv layers.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : str
+        The nonlinearity function to use.
+    predict_logvar : str or None
+        If not `None`, an additional log-variance channel is predicted for each
+        output channel.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+
+    Raises
+    ------
+    NotImplementedError
+        If the requested convolution dimensionality is not supported (only 2D and
+        3D convolutions are handled).
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, ...],
+        output_channels: int,
+        multiscale_count: int,
+        z_dims: list[int],
+        encoder_n_filters: int,
+        decoder_n_filters: int,
+        encoder_conv_strides: list[int],
+        decoder_conv_strides: list[int],
+        encoder_dropout: float,
+        decoder_dropout: float,
+        nonlinearity: str,
+        predict_logvar: Union[str, None],
+        analytical_kl: bool,
+    ):
+        super().__init__()
+
+        # -------------------------------------------------------
+        # Customizable attributes
+        self.image_size = input_shape
+        """Input image size. (Z, Y, X) or (Y, X) if the data is 2D."""
+        # TODO: we need to be careful with this since it used to be an int.
+        # the tuple of shapes used to be `self.input_shape`.
+        self.target_ch = output_channels
+        self.encoder_conv_strides = encoder_conv_strides
+        self.decoder_conv_strides = decoder_conv_strides
+        self._multiscale_count = multiscale_count
+        self.z_dims = z_dims
+        self.encoder_n_filters = encoder_n_filters
+        self.decoder_n_filters = decoder_n_filters
+        self.encoder_dropout = encoder_dropout
+        self.decoder_dropout = decoder_dropout
+        self.nonlin = nonlinearity
+        self.predict_logvar = predict_logvar
+        self.analytical_kl = analytical_kl
+        # -------------------------------------------------------
+
+        # -------------------------------------------------------
+        # Model attributes -> Hardcoded
+        self.model_type = ModelType.LadderVae  # TODO remove !
+        self.encoder_blocks_per_layer = 1
+        self.decoder_blocks_per_layer = 1
+        self.bottomup_batchnorm = True
+        self.topdown_batchnorm = True
+        self.topdown_conv2d_bias = True
+        self.gated = True
+        self.encoder_res_block_kernel = 3
+        self.decoder_res_block_kernel = 3
+        self.encoder_res_block_skip_padding = False
+        self.decoder_res_block_skip_padding = False
+        self.merge_type = "residual"
+        self.no_initial_downscaling = True
+        self.skip_bottomk_buvalues = 0
+        self.stochastic_skip = True
+        self.learn_top_prior = True
+        self.res_block_type = "bacdbacd"  # TODO remove !
+        self.mode_pred = False
+        self.logvar_lowerbound = -5
+        self._var_clip_max = 20
+        self._stochastic_use_naive_exponential = False
+        self._enable_topdown_normalize_factor = True
+
+        # Attributes that handle LC -> Hardcoded
+        self.enable_multiscale = self._multiscale_count > 1
+        self.multiscale_retain_spatial_dims = True
+        self.multiscale_lowres_separate_branch = False
+        self.multiscale_decoder_retain_spatial_dims = (
+            self.multiscale_retain_spatial_dims and self.enable_multiscale
+        )
+
+        # Derived attributes
+        self.n_layers = len(self.z_dims)
+
+        # Others...
+        self._tethered_to_input = False
+        self._tethered_ch1_scalar = self._tethered_ch2_scalar = None
+        if self._tethered_to_input:
+            self.target_ch = 1
+            requires_grad = False
+            self._tethered_ch1_scalar = nn.Parameter(
+                torch.ones(1) * 0.5, requires_grad=requires_grad
+            )
+            self._tethered_ch2_scalar = nn.Parameter(
+                torch.ones(1) * 2.0, requires_grad=requires_grad
+            )
+        # -------------------------------------------------------
+
+        # -------------------------------------------------------
+        # Data attributes
+        self.color_ch = 1  # TODO for now we only support 1 channel
+        self.normalized_input = True
+        # -------------------------------------------------------
+
+        # -------------------------------------------------------
+        # Loss attributes
+        # enabling reconstruction loss on mixed input
+        self.mixed_rec_w = 0
+        self.nbr_consistency_w = 0
+
+        # -------------------------------------------------------
+        # 3D related stuff
+        self._mode_3D = len(self.image_size) == 3  # TODO refac
+        self._model_3D_depth = self.image_size[0] if self._mode_3D else 1
+        self._decoder_mode_3D = len(self.decoder_conv_strides) == 3
+        if self._mode_3D and not self._decoder_mode_3D:
+            assert self._model_3D_depth % 2 == 1, "3D model depth should be odd"
+        assert (
+            self._mode_3D is True or self._decoder_mode_3D is False
+        ), "Decoder cannot be 3D when encoder is 2D"
+        self._squish3d = self._mode_3D and not self._decoder_mode_3D
+        self._3D_squisher = (
+            None
+            if not self._squish3d
+            else nn.ModuleList(
+                [
+                    GateLayer(
+                        channels=self.encoder_n_filters,
+                        conv_strides=self.encoder_conv_strides,
+                    )
+                    for k in range(len(self.z_dims))
+                ]
+            )
+        )
+        # TODO: this bit is in Ashesh's confusing, hacky style... Can we do better?
+
+        # -------------------------------------------------------
+        # # Training attributes
+        # # can be used to tile the validation predictions
+        # self._val_idx_manager = val_idx_manager
+        # self._val_frame_creator = None
+        # # initialize the learning rate scheduler params.
+        # self.lr_scheduler_monitor = self.lr_scheduler_mode = None
+        # self._init_lr_scheduler_params(config)
+        # self._global_step = 0
+        # -------------------------------------------------------
+
+        # -------------------------------------------------------
+
+        # Calculate the downsampling happening in the network
+        self.downsample = [1] * self.n_layers
+        self.overall_downscale_factor = np.power(2, sum(self.downsample))
+        if not self.no_initial_downscaling:  # by default do another downscaling
+            self.overall_downscale_factor *= 2
+
+        assert max(self.downsample) <= self.encoder_blocks_per_layer
+        assert len(self.downsample) == self.n_layers
+        # -------------------------------------------------------
+
+        # -------------------------------------------------------
+        ### CREATE MODEL BLOCKS
+        # First bottom-up layer: change num channels + downsample by factor 2
+        # unless we want to prevent this
+        self.encoder_conv_op = getattr(nn, f"Conv{len(self.encoder_conv_strides)}d")
+        # TODO these should be defined for all layers here ?
+        self.decoder_conv_op = getattr(nn, f"Conv{len(self.decoder_conv_strides)}d")
+        # TODO: it would be more readable to have a derived parameter like
+        # `conv_dims = len(self.encoder_conv_strides)` and then use `Conv{conv_dims}d`
+        stride = 1 if self.no_initial_downscaling else 2
+        self.first_bottom_up = self.create_first_bottom_up(stride)
+
+        # Input Branches for Lateral Contextualization
+        self.lowres_first_bottom_ups = None
+        self._init_multires()
+
+        # Other bottom-up layers
+        self.bottom_up_layers = self.create_bottom_up_layers(
+            self.multiscale_lowres_separate_branch
+        )
+
+        # Top-down layers
+        self.top_down_layers = self.create_top_down_layers()
+        self.final_top_down = self.create_final_topdown_layer(
+            not self.no_initial_downscaling
+        )
+
+        # Likelihood module
+        # self.likelihood = self.create_likelihood_module()
+
+        # Output layer --> Project to target_ch many channels
+        logvar_ch_needed = self.predict_logvar is not None
+        self.output_layer = self.parameter_net = self.decoder_conv_op(
+            self.decoder_n_filters,
+            self.target_ch * (1 + logvar_ch_needed),
+            kernel_size=3,
+            padding=1,
+            bias=self.topdown_conv2d_bias,
+        )
+
+        # # gradient norms. updated while training. this is also logged.
+        # self.grad_norm_bottom_up = 0.0
+        # self.grad_norm_top_down = 0.0
+        # PSNR computation on validation.
+        # self.label1_psnr = RunningPSNR()
+        # self.label2_psnr = RunningPSNR()
+        # TODO: did you add this?
+
+        # msg =f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
+        # msg += f' TargetCh: {self.target_ch}'
+        # print(msg)
+
+    ### SET OF METHODS TO CREATE MODEL BLOCKS
+    def create_first_bottom_up(
+        self,
+        init_stride: int,
+        num_res_blocks: int = 1,
+    ) -> nn.Sequential:
+        """
+        Method creates the first bottom-up block of the Encoder.
+
+        Its role is to perform a first image compression step.
+        It is composed of a sequence of nn.Conv2d + non-linearity +
+        BottomUpDeterministicResBlock (1 or more, default is 1).
+
+        Parameters
+        ----------
+        init_stride: int
+            The stride used by the initial Conv2d block.
+        num_res_blocks: int, optional
+            The number of BottomUpDeterministicResBlocks, default is 1.
+        """
+        # From what I got from Ashesh, Z should not be touched in any case.
+        nonlin = get_activation(self.nonlin)
+        conv_block = self.encoder_conv_op(
+            in_channels=self.color_ch,
+            out_channels=self.encoder_n_filters,
+            kernel_size=self.encoder_res_block_kernel,
+            padding=(
+                0
+                if self.encoder_res_block_skip_padding
+                else self.encoder_res_block_kernel // 2
+            ),
+            stride=init_stride,
+        )
+
+        modules = [conv_block, nonlin]
+
+        for _ in range(num_res_blocks):
+            modules.append(
+                BottomUpDeterministicResBlock(
+                    conv_strides=self.encoder_conv_strides,
+                    c_in=self.encoder_n_filters,
+                    c_out=self.encoder_n_filters,
+                    nonlin=nonlin,
+                    downsample=False,
+                    batchnorm=self.bottomup_batchnorm,
+                    dropout=self.encoder_dropout,
+                    res_block_type=self.res_block_type,
+                    res_block_kernel=self.encoder_res_block_kernel,
+                )
+            )
+
+        return nn.Sequential(*modules)
+
+    def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
+        """
+        Method creates the stack of bottom-up layers of the Encoder
+        that are used to generate the so-called `bu_values`.
+
+        NOTE:
+        If `self._multiscale_count < self.n_layers`, then LC is done only in the first
+        `self._multiscale_count` bottom-up layers (starting from the bottom).
+
+        Parameters
+        ----------
+        lowres_separate_branch: bool
+            Whether the residual block(s) used for encoding the low-res input are shared
+            (`False`) or not (`True`) with the "same-size" residual block(s) in the
+            `BottomUpLayer`'s primary flow.
+        """
+        multiscale_lowres_size_factor = 1
+        nonlin = get_activation(self.nonlin)
+
+        bottom_up_layers = nn.ModuleList([])
+        for i in range(self.n_layers):
+            # Whether this is the top layer
+            is_top = i == self.n_layers - 1
+
+            # LC is applied only to the first (_multiscale_count - 1) bottom-up layers
+            layer_enable_multiscale = (
+                self.enable_multiscale and self._multiscale_count > i + 1
+            )
+
+            # This determines the factor by which the low-resolution tensor is larger.
+            # N.B. Only used if layer_enable_multiscale == True, so we update it only
+            # in that case.
+            multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
+
+            # TODO: check correctness of this
+            if self._multiscale_count > 1:
+                output_expected_shape = (dim // 2 ** (i + 1) for dim in self.image_size)
+            else:
+                output_expected_shape = None
+
+            # Add bottom-up deterministic layer at level i.
+            # It's a sequence of residual blocks (BottomUpDeterministicResBlock),
+            # possibly with downsampling between them.
+            bottom_up_layers.append(
+                BottomUpLayer(
+                    n_res_blocks=self.encoder_blocks_per_layer,
+                    n_filters=self.encoder_n_filters,
+                    downsampling_steps=self.downsample[i],
+                    nonlin=nonlin,
+                    conv_strides=self.encoder_conv_strides,
+                    batchnorm=self.bottomup_batchnorm,
+                    dropout=self.encoder_dropout,
+                    res_block_type=self.res_block_type,
+                    res_block_kernel=self.encoder_res_block_kernel,
+                    gated=self.gated,
+                    lowres_separate_branch=lowres_separate_branch,
+                    enable_multiscale=self.enable_multiscale,  # TODO: shouldn't the arg be `layer_enable_multiscale` here?
+                    multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
+                    multiscale_lowres_size_factor=multiscale_lowres_size_factor,
+                    decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                    output_expected_shape=output_expected_shape,
+                )
+            )
+
+        return bottom_up_layers
+
+    def create_top_down_layers(self) -> nn.ModuleList:
+        """
+        Method creates the stack of top-down layers of the Decoder.
+
+        In these layers the `bu_values` from the Encoder are merged with the `p_params`
+        from the previous layer of the Decoder to get `q_params`. Then, a stochastic
+        layer generates a sample from the latent distribution with parameters
+        `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock
+        to compute the `p_params` for the layer below.
+
+        NOTE 1:
+        The algorithm for generative inference approximately works as follows
+        (a toy sketch of one such step is given after this method):
+        - p_params = output of top-down layer above
+        - bu = inferred bottom-up value at this layer
+        - q_params = merge(bu, p_params)
+        - z = stochastic_layer(q_params)
+        - (optional) get and merge skip connection from prev top-down layer
+        - top-down deterministic ResNet
+
+        NOTE 2:
+        When doing unconditional generation, bu_value is not available. Hence the
+        merge layer is not used, and z is sampled directly from p_params.
+        """
+        top_down_layers = nn.ModuleList([])
+        nonlin = get_activation(self.nonlin)
+        # NOTE: top-down layers are created starting from the bottom-most
+        for i in range(self.n_layers):
+            # Check if this is the top layer
+            is_top = i == self.n_layers - 1
+
+            if self._enable_topdown_normalize_factor:  # TODO: What is this?
+                normalize_latent_factor = (
+                    1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
+                )
+            else:
+                normalize_latent_factor = 1.0
+
+            top_down_layers.append(
+                TopDownLayer(
+                    z_dim=self.z_dims[i],
+                    n_res_blocks=self.decoder_blocks_per_layer,
+                    n_filters=self.decoder_n_filters,
+                    is_top_layer=is_top,
+                    conv_strides=self.decoder_conv_strides,
+                    upsampling_steps=self.downsample[i],
+                    nonlin=nonlin,
+                    merge_type=self.merge_type,
+                    batchnorm=self.topdown_batchnorm,
+                    dropout=self.decoder_dropout,
+                    stochastic_skip=self.stochastic_skip,
+                    learn_top_prior=self.learn_top_prior,
+                    top_prior_param_shape=self.get_top_prior_param_shape(),
+                    res_block_type=self.res_block_type,
+                    res_block_kernel=self.decoder_res_block_kernel,
+                    gated=self.gated,
+                    analytical_kl=self.analytical_kl,
+                    vanilla_latent_hw=self.get_latent_spatial_size(i),
+                    retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                    input_image_shape=self.image_size,
+                    normalize_latent_factor=normalize_latent_factor,
+                    conv2d_bias=self.topdown_conv2d_bias,
+                    stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
+                )
+            )
+        return top_down_layers
+
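# A minimal, self-contained sketch of the generative-inference step from NOTE 1
# above (not part of the package): it assumes Gaussian posteriors and a residual
# merge, with `merge` and `res_block` as hypothetical stand-ins for the
# TopDownLayer internals.
import torch

def top_down_step(p_params, bu_value, merge, res_block):
    q_params = merge(bu_value, p_params)   # combine bottom-up evidence with prior
    mu, logvar = q_params.chunk(2, dim=1)  # split channels into mean / log-variance
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterized sample
    return res_block(z)  # deterministic part computes p_params for the layer below

merge = lambda bu, p: bu + p  # the "residual" merge_type, reduced to its essence
res_block = torch.nn.Conv2d(16, 64, kernel_size=3, padding=1)
out = top_down_step(torch.randn(1, 32, 8, 8), torch.randn(1, 32, 8, 8), merge, res_block)
# out has shape (1, 64, 8, 8): the deterministic block maps z back to feature space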
+    def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
+        """Create the final top-down layer of the Decoder.
+
+        NOTE: In this layer, (optional) upsampling is performed by bilinear
+        interpolation instead of transposed convolution (like in other TD layers).
+
+        Parameters
+        ----------
+        upsample: bool
+            Whether to upsample the input of the final top-down layer
+            by bilinear interpolation with `scale_factor=2`.
+        """
+        # Final top-down layer
+        modules = list()
+
+        if upsample:
+            modules.append(Interpolate(scale=2))
+
+        for _ in range(self.decoder_blocks_per_layer):
+            modules.append(
+                TopDownDeterministicResBlock(
+                    c_in=self.decoder_n_filters,
+                    c_out=self.decoder_n_filters,
+                    nonlin=get_activation(self.nonlin),
+                    conv_strides=self.decoder_conv_strides,
+                    batchnorm=self.topdown_batchnorm,
+                    dropout=self.decoder_dropout,
+                    res_block_type=self.res_block_type,
+                    res_block_kernel=self.decoder_res_block_kernel,
+                    gated=self.gated,
+                    conv2d_bias=self.topdown_conv2d_bias,
+                )
+            )
+        return nn.Sequential(*modules)
+
+    def _init_multires(self, config=None) -> None:
+        """
+        Method defines the input block/branch to encode/compress low-res lateral
+        inputs at different hierarchical levels in the multiresolution approach (LC).
+
+        The role of the input branches is similar
+        to the one of the first bottom-up layer in the primary flow of the Encoder,
+        namely to compress the lateral input image to a degree that is compatible with
+        the one of the primary flow.
+
+        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity
+        + BottomUpDeterministicResBlock. It is meaningful to observe that the
+        `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc.).
+        Moreover, it does not perform downsampling.
+
+        NOTE 2: The `_multiscale_count` attribute defines the total number of inputs
+        to the bottom-up pass. In other terms, if we have the input patch and n_LC
+        additional lateral inputs, we will have a total of (n_LC + 1) inputs.
+        """
+        stride = 1 if self.no_initial_downscaling else 2
+        nonlin = get_activation(self.nonlin)
+        if self._multiscale_count is None:
+            self._multiscale_count = 1
+
+        msg = (
+            f"Multiscale count ({self._multiscale_count}) should not exceed the number "
+            f"of bottom-up layers ({self.n_layers}) by more than 1.\n"
+        )
+        assert (
+            self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
+        ), msg  # TODO how ?
+
+        msg = (
+            "Multiscale approach only supports monochrome images. "
+            f"Found instead color_ch={self.color_ch}."
+        )
+        # assert self._multiscale_count == 1 or self.color_ch == 1, msg
+
+        lowres_first_bottom_ups = []
+        for _ in range(1, self._multiscale_count):
+            first_bottom_up = nn.Sequential(
+                self.encoder_conv_op(
+                    in_channels=self.color_ch,
+                    out_channels=self.encoder_n_filters,
+                    kernel_size=5,
+                    padding="same",
+                    stride=stride,
+                ),
+                nonlin,
+                BottomUpDeterministicResBlock(
+                    c_in=self.encoder_n_filters,
+                    c_out=self.encoder_n_filters,
+                    conv_strides=self.encoder_conv_strides,
+                    nonlin=nonlin,
+                    downsample=False,
+                    batchnorm=self.bottomup_batchnorm,
+                    dropout=self.encoder_dropout,
+                    res_block_type=self.res_block_type,
+                ),
+            )
+            lowres_first_bottom_ups.append(first_bottom_up)
+
+        self.lowres_first_bottom_ups = (
+            nn.ModuleList(lowres_first_bottom_ups)
+            if len(lowres_first_bottom_ups)
+            else None
+        )
+
+    ### SET OF FORWARD-LIKE METHODS
+    def bottomup_pass(self, inp: torch.Tensor) -> list[torch.Tensor]:
+        """Wrapper of _bottomup_pass()."""
+        # TODO Remove wrapper
+        return self._bottomup_pass(
+            inp,
+            self.first_bottom_up,
+            self.lowres_first_bottom_ups,
+            self.bottom_up_layers,
+        )
+
+    def _bottomup_pass(
+        self,
+        inp: torch.Tensor,
+        first_bottom_up: nn.Sequential,
+        lowres_first_bottom_ups: nn.ModuleList,
+        bottom_up_layers: nn.ModuleList,
+    ) -> list[torch.Tensor]:
+        """
+        Method defines the forward pass through the LVAE Encoder,
+        the so-called Bottom-Up pass.
+
+        Parameters
+        ----------
+        inp: torch.Tensor
+            The input tensor to the bottom-up pass of shape (B, 1+n_LC, H, W), where
+            n_LC is the number of lateral low-res inputs used in the LC approach.
+            In particular, the first channel corresponds to the input patch, while the
+            remaining ones are associated with the lateral low-res inputs.
+        first_bottom_up: nn.Sequential
+            The module defining the first bottom-up layer of the Encoder.
+        lowres_first_bottom_ups: nn.ModuleList
+            The list of modules defining Lateral Contextualization.
+        bottom_up_layers: nn.ModuleList
+            The list of modules defining the stack of bottom-up layers of the Encoder.
+        """
+        if self._multiscale_count > 1:
+            x = first_bottom_up(inp[:, :1])
+        else:
+            x = first_bottom_up(inp)
+
+        # Loop from bottom to top layer, store all deterministic nodes we
+        # need for the top-down pass in bu_values list
+        bu_values = []
+        for i in range(self.n_layers):
+            lowres_x = None
+            if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
+                lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
+            x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
+            bu_values.append(bu_value)
+
+        return bu_values
+
+    def topdown_pass(
+        self,
+        bu_values: Union[torch.Tensor, None] = None,
+        n_img_prior: Union[torch.Tensor, None] = None,
+        constant_layers: Union[Iterable[int], None] = None,
+        forced_latent: Union[list[torch.Tensor], None] = None,
+        top_down_layers: Union[nn.ModuleList, None] = None,
+        final_top_down_layer: Union[nn.Sequential, None] = None,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """
+        Method defines the forward pass through the LVAE Decoder,
+        the so-called Top-Down pass.
+
+        Parameters
+        ----------
+        bu_values: torch.Tensor, optional
+            Output of the bottom-up pass. It will have values from multiple layers of
+            the ladder.
+        n_img_prior: optional
+            When `bu_values` is `None`, `n_img_prior` indicates the number of images
+            to generate from the prior (so the bottom-up pass is not used at all here).
+        constant_layers: Iterable[int], optional
+            A sequence of indices associated with the layers in which a single
+            instance's z is copied over the entire batch (bottom-up path is not used,
+            so only the prior is used here). Set to `None` to avoid this behaviour.
+        forced_latent: list[torch.Tensor], optional
+            A list of tensors that are used as fixed latent variables (hence, sampling
+            doesn't take place in this case).
+        top_down_layers: nn.ModuleList, optional
+            A list of top-down layers to use in the top-down pass. If `None`, the
+            method uses the default layers defined in the constructor.
+        final_top_down_layer: nn.Sequential, optional
+            The last top-down layer of the top-down pass. If `None`, the method uses
+            the default layers defined in the constructor.
+        """
+        if top_down_layers is None:
+            top_down_layers = self.top_down_layers
+        if final_top_down_layer is None:
+            final_top_down_layer = self.final_top_down
+
+        # Default: no layer is sampled from the distribution's mode
+        if constant_layers is None:
+            constant_layers = []
+        prior_experiment = len(constant_layers) > 0
+
+        # If the bottom-up inference values are not given, don't do
+        # inference, sample from prior instead
+        inference_mode = bu_values is not None
+
+        # Check consistency of arguments
+        if inference_mode != (n_img_prior is None):
+            msg = (
+                "Number of images for top-down generation has to be given "
+                "if and only if we're not doing inference"
+            )
+            raise RuntimeError(msg)
+        if inference_mode and prior_experiment:
+            msg = (
+                "Prior experiments (e.g. sampling from mode) are not"
+                " compatible with inference mode"
+            )
+            raise RuntimeError(msg)
+
+        # Sampled latent variables at each layer
+        z = [None] * self.n_layers
+        # KL divergence of each layer
+        kl = [None] * self.n_layers
+        # Restricted KL divergence, used only in the LC-enabled denoiSplit setup.
+        kl_restricted = [None] * self.n_layers
+        # mean from which z is sampled.
+        q_mu = [None] * self.n_layers
+        # log(var) from which z is sampled.
+        q_lv = [None] * self.n_layers
+        # Spatial map of KL divergence for each layer
+        kl_spatial = [None] * self.n_layers
+        debug_qvar_max = [None] * self.n_layers
+        kl_channelwise = [None] * self.n_layers
+        if forced_latent is None:
+            forced_latent = [None] * self.n_layers
+
+        # Top-down inference/generation loop
+        out = None
+        for i in reversed(range(self.n_layers)):
+            # If available, get deterministic node from bottom-up inference
+            try:
+                bu_value = bu_values[i]
+            except TypeError:
+                bu_value = None
+
+            # Whether the current layer should be sampled from the mode
+            constant_out = i in constant_layers
+
+            # Input for skip connection
+            skip_input = out
+
+            # Full top-down layer, including sampling and deterministic part
+            out, aux = top_down_layers[i](
+                input_=out,
+                skip_connection_input=skip_input,
+                inference_mode=inference_mode,
+                bu_value=bu_value,
+                n_img_prior=n_img_prior,
+                force_constant_output=constant_out,
+                forced_latent=forced_latent[i],
+                mode_pred=self.mode_pred,
+                var_clip_max=self._var_clip_max,
+            )
+            # Save useful variables
+            z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
+            kl[i] = aux["kl_samplewise"]  # (batch, )
+            kl_restricted[i] = aux["kl_samplewise_restricted"]
+            kl_spatial[i] = aux["kl_spatial"]  # (batch, h, w)
+            q_mu[i] = aux["q_mu"]
+            q_lv[i] = aux["q_lv"]
+
+            kl_channelwise[i] = aux["kl_channelwise"]
+            debug_qvar_max[i] = aux["qvar_max"]
+            # if self.mode_pred is False:
+            #     logprob_p += aux['logprob_p'].mean()  # mean over batch
+            # else:
+            #     logprob_p = None
+
+        # Final top-down layer
+        out = final_top_down_layer(out)
+
+        # Store useful variables in a dict to return them
+        data = {
+            "z": z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
+            "kl": kl,  # list of tensors with shape (batch, )
+            "kl_restricted": kl_restricted,  # list of tensors with shape (batch, )
+            "kl_spatial": kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
+            "kl_channelwise": kl_channelwise,  # list of tensors with shape (batch, ch[i])
+            # 'logprob_p': logprob_p,  # scalar, mean over batch
+            "q_mu": q_mu,
+            "q_lv": q_lv,
+            "debug_qvar_max": debug_qvar_max,
+        }
+        return out, data
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """
+        Forward pass through the LVAE model.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The input tensor of shape (B, C, H, W).
+        """
+        img_size = x.size()[2:]
+
+        # Bottom-up inference: returns a list of length n_layers (bottom to top)
+        bu_values = self.bottomup_pass(x)
+        for i in range(0, self.skip_bottomk_buvalues):
+            bu_values[i] = None
+
+        if self._squish3d:
+            bu_values = [
+                torch.mean(self._3D_squisher[k](bu_value), dim=2)
+                for k, bu_value in enumerate(bu_values)
+            ]
+
+        # Top-down inference/generation
+        out, td_data = self.topdown_pass(bu_values)
+
+        if out.shape[-1] > img_size[-1]:
+            # Restore original image size
+            out = crop_img_tensor(out, img_size)
+
+        out = self.output_layer(out)
+
+        return out, td_data
+
+    ### SET OF GETTERS
+    def get_padded_size(self, size):
+        """
+        Returns the smallest size (H, W) greater than or equal to the input size
+        such that H and W are multiples of the overall downscale factor (a power
+        of 2).
+        :param size: input size, tuple either (N, C, H, W) or (H, W)
+        :return: 2-tuple (H, W)
+        """
+        # Make size argument into (height, width)
+        # assert len(size) in [2, 4, 5]  # TODO commented out cuz it's weird
+        # We're only interested in the Y,X dimensions
+        size = size[-2:]
+
+        if self.multiscale_decoder_retain_spatial_dims is True:
+            # In this case, we can go much deeper, so this is not required as-is;
+            # more work would be needed to implement this correctly.
+            return list(size)
+
+        # Overall downscale factor from input to top layer (power of 2)
+        dwnsc = self.overall_downscale_factor
+
+        # Output the smallest multiples of the downscale factor that are at least
+        # as large as the current sizes
+        padded_size = [((s - 1) // dwnsc + 1) * dwnsc for s in size]
+        # TODO Needed for pad/crop odd sizes. Move to dataset?
+        return padded_size
+
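# A quick numeric check of the padding rule above (illustrative, not part of the
# package): sizes are rounded up to the nearest multiple of the downscale factor.
def padded(size, dwnsc=16):
    return [((s - 1) // dwnsc + 1) * dwnsc for s in size]

assert padded((64, 64)) == [64, 64]    # already a multiple of 16: unchanged
assert padded((65, 100)) == [80, 112]  # rounded up to the next multiple of 16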
+    def get_latent_spatial_size(self, level_idx: int):
+        """`level_idx`: 0 is the bottommost layer, the highest-resolution one."""
+        actual_downsampling = level_idx + 1
+        dwnsc = 2**actual_downsampling
+        sz = self.get_padded_size(self.image_size)
+        h = sz[0] // dwnsc
+        w = sz[1] // dwnsc
+        assert h == w
+        return h
+
+    def get_top_prior_param_shape(self, n_imgs: int = 1):
+
+        # Compute the total downscaling performed in the Encoder
+        if self.multiscale_decoder_retain_spatial_dims is False:
+            dwnsc = self.overall_downscale_factor
+        else:
+            # LC allows the encoder latents to keep the same (H, W) size at
+            # different levels
+            actual_downsampling = self.n_layers + 1 - self._multiscale_count
+            dwnsc = 2**actual_downsampling
+
+        h = self.image_size[-2] // dwnsc
+        w = self.image_size[-1] // dwnsc
+        mu_logvar = self.z_dims[-1] * 2  # mu and logvar
+        top_layer_shape = (n_imgs, mu_logvar, h, w)
+        # TODO refactor!
+        if self._model_3D_depth > 1 and self._decoder_mode_3D is True:
+            # TODO check if model_3D_depth is needed ?
+            top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
+        return top_layer_shape
+
+    def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
+        """Should be called if we want to predict for a different input/output size."""
+        self.mode_pred = True
+        if tile_size is None:
+            tile_size = self.image_size
+        self.image_size = tile_size
+        for i in range(self.n_layers):
+            self.bottom_up_layers[i].output_expected_shape = (
+                ts // 2 ** (i + 1) for ts in tile_size
+            )
+            self.top_down_layers[i].latent_shape = tile_size
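
For reference, a minimal sketch of how this model might be instantiated and run, based on the constructor signature shown in the diff above. The parameter values are illustrative assumptions rather than package defaults, and follow the 2D configuration (strides of length 2, input size a multiple of the overall downscale factor):

    import torch
    from careamics.models.lvae.lvae import LadderVAE

    model = LadderVAE(
        input_shape=(64, 64),         # (Y, X) for 2D data
        output_channels=2,
        multiscale_count=1,           # 1 disables lateral contextualization
        z_dims=[128, 128, 128, 128],  # four hierarchy levels
        encoder_n_filters=64,
        decoder_n_filters=64,
        encoder_conv_strides=[2, 2],  # length 2 selects nn.Conv2d
        decoder_conv_strides=[2, 2],
        encoder_dropout=0.1,
        decoder_dropout=0.1,
        nonlinearity="ELU",
        predict_logvar=None,          # no extra log-variance channels
        analytical_kl=False,
    )
    x = torch.randn(1, 1, 64, 64)     # (B, C, Y, X); color_ch is hardcoded to 1
    out, td_data = model(x)           # out: (1, 2, 64, 64); td_data holds KL terms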