careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/lvae/lvae.py
@@ -0,0 +1,901 @@
+ """
+ Ladder VAE (LVAE) Model
+
+ The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal" (Prakash et al.).
+ """
+
+ from collections.abc import Iterable
+ from typing import Dict, List, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from careamics.config.architectures import register_model
+
+ from ..activation import get_activation
+ from .layers import (
+     BottomUpDeterministicResBlock,
+     BottomUpLayer,
+     TopDownDeterministicResBlock,
+     TopDownLayer,
+ )
+ from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor
+
+
+ @register_model("LVAE")
+ class LadderVAE(nn.Module):
+
+     def __init__(
+         self,
+         input_shape: int,
+         output_channels: int,
+         multiscale_count: int,
+         z_dims: List[int],
+         encoder_n_filters: int,
+         decoder_n_filters: int,
+         encoder_dropout: float,
+         decoder_dropout: float,
+         nonlinearity: str,
+         predict_logvar: bool,
+         enable_noise_model: bool,
+         analytical_kl: bool,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         input_shape: int
+             Size of the (square) input patches, i.e., H = W.
+         output_channels: int
+             Number of target channels to predict.
+         multiscale_count: int
+             Total number of inputs to the bottom-up pass, i.e., the input patch
+             plus the lateral low-resolution inputs used for LC.
+         z_dims: List[int]
+             Latent space dimension for each hierarchy level.
+         encoder_n_filters: int
+             Number of filters in the encoder's convolutional blocks.
+         decoder_n_filters: int
+             Number of filters in the decoder's convolutional blocks.
+         encoder_dropout: float
+             Dropout probability used in the encoder.
+         decoder_dropout: float
+             Dropout probability used in the decoder.
+         nonlinearity: str
+             Name of the activation function, resolved by `get_activation`.
+         predict_logvar: bool
+             Whether to also predict a log-variance channel per target channel
+             (the constructor checks `predict_logvar is not None`).
+         enable_noise_model: bool
+             Whether a noise model is used.
+         analytical_kl: bool
+             Whether the KL divergence is computed analytically.
+         """
+         super().__init__()
+
+         # -------------------------------------------------------
+         # Customizable attributes
+         self.image_size = input_shape
+         self.target_ch = output_channels
+         self._multiscale_count = multiscale_count
+         self.z_dims = z_dims
+         self.encoder_n_filters = encoder_n_filters
+         self.decoder_n_filters = decoder_n_filters
+         self.encoder_dropout = encoder_dropout
+         self.decoder_dropout = decoder_dropout
+         self.nonlin = nonlinearity
+         self.predict_logvar = predict_logvar
+         self.enable_noise_model = enable_noise_model
+
+         self.analytical_kl = analytical_kl
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Model attributes -> Hardcoded
+         self.model_type = ModelType.LadderVae  # TODO remove !
+         self.encoder_blocks_per_layer = 1
+         self.decoder_blocks_per_layer = 1
+         self.bottomup_batchnorm = True
+         self.topdown_batchnorm = True
+         self.topdown_conv2d_bias = True
+         self.gated = True
+         self.encoder_res_block_kernel = 3
+         self.decoder_res_block_kernel = 3
+         self.encoder_res_block_skip_padding = False
+         self.decoder_res_block_skip_padding = False
+         self.merge_type = "residual"
+         self.no_initial_downscaling = True
+         self.skip_bottomk_buvalues = 0
+         self.non_stochastic_version = False
+         self.stochastic_skip = True
+         self.learn_top_prior = True
+         self.res_block_type = "bacdbacd"  # TODO remove !
+         self.mode_pred = False
+         self.logvar_lowerbound = -5
+         self._var_clip_max = 20
+         self._stochastic_use_naive_exponential = False
+         self._enable_topdown_normalize_factor = True
+
+         # Noise model attributes -> Hardcoded
+         self.noise_model_type = "gmm"
+         self.denoise_channel = (
+             "input"  # 4 values for denoise_channel: {'Ch1', 'Ch2', 'input', 'all'}
+         )
+         self.noise_model_learnable = False
+
+         # Attributes that handle LC -> Hardcoded
+         self.enable_multiscale = (
+             self._multiscale_count is not None and self._multiscale_count > 1
+         )
+         self.multiscale_retain_spatial_dims = True
+         self.multiscale_lowres_separate_branch = False
+         self.multiscale_decoder_retain_spatial_dims = (
+             self.multiscale_retain_spatial_dims and self.enable_multiscale
+         )
+
+         # Derived attributes
+         self.n_layers = len(self.z_dims)
+         self.encoder_no_padding_mode = (
+             self.encoder_res_block_skip_padding is True
+             and self.encoder_res_block_kernel > 1
+         )
+         self.decoder_no_padding_mode = (
+             self.decoder_res_block_skip_padding is True
+             and self.decoder_res_block_kernel > 1
+         )
+
+         # Others...
+         self._tethered_to_input = False
+         self._tethered_ch1_scalar = self._tethered_ch2_scalar = None
+         if self._tethered_to_input:
+             target_ch = 1
+             requires_grad = False
+             self._tethered_ch1_scalar = nn.Parameter(
+                 torch.ones(1) * 0.5, requires_grad=requires_grad
+             )
+             self._tethered_ch2_scalar = nn.Parameter(
+                 torch.ones(1) * 2.0, requires_grad=requires_grad
+             )
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Data attributes
+         self.color_ch = 1
+         self.img_shape = (self.image_size, self.image_size)
+         self.normalized_input = True
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Loss attributes
+         self._restricted_kl = False  # HC
+         # enabling reconstruction loss on mixed input
+         self.mixed_rec_w = 0
+         self.nbr_consistency_w = 0
+
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # # Training attributes
+         # # can be used to tile the validation predictions
+         # self._val_idx_manager = val_idx_manager
+         # self._val_frame_creator = None
+         # # initialize the learning rate scheduler params.
+         # self.lr_scheduler_monitor = self.lr_scheduler_mode = None
+         # self._init_lr_scheduler_params(config)
+         # self._global_step = 0
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+
+         # Calculate the downsampling happening in the network
+         self.downsample = [1] * self.n_layers
+         self.overall_downscale_factor = np.power(2, sum(self.downsample))
+         if not self.no_initial_downscaling:  # by default do another downscaling
+             self.overall_downscale_factor *= 2
+
+         assert max(self.downsample) <= self.encoder_blocks_per_layer
+         assert len(self.downsample) == self.n_layers
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         ### CREATE MODEL BLOCKS
+         # First bottom-up layer: change num channels + downsample by factor 2,
+         # unless we want to prevent this
+         stride = 1 if self.no_initial_downscaling else 2
+         self.first_bottom_up = self.create_first_bottom_up(stride)
+
+         # Input Branches for Lateral Contextualization
+         self.lowres_first_bottom_ups = None
+         self._init_multires()
+
+         # Other bottom-up layers
+         self.bottom_up_layers = self.create_bottom_up_layers(
+             self.multiscale_lowres_separate_branch
+         )
+
+         # Top-down layers
+         self.top_down_layers = self.create_top_down_layers()
+         self.final_top_down = self.create_final_topdown_layer(
+             not self.no_initial_downscaling
+         )
+
+         # Likelihood module
+         # self.likelihood = self.create_likelihood_module()
+
+         # Output layer --> Project to target_ch many channels
+         logvar_ch_needed = self.predict_logvar is not None
+         self.output_layer = self.parameter_net = nn.Conv2d(
+             self.decoder_n_filters,
+             self.target_ch * (1 + logvar_ch_needed),
+             kernel_size=3,
+             padding=1,
+             bias=self.topdown_conv2d_bias,
+         )
+
+         # # gradient norms. updated while training. this is also logged.
+         # self.grad_norm_bottom_up = 0.0
+         # self.grad_norm_top_down = 0.0
+         # PSNR computation on validation.
+         # self.label1_psnr = RunningPSNR()
+         # self.label2_psnr = RunningPSNR()
+
+         # msg = f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
+         # msg += f' TargetCh: {self.target_ch}'
+         # print(msg)
+
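For orientation, a minimal instantiation sketch follows. It only mirrors the constructor signature above; the concrete values (patch size, filter counts, the "ELU" activation name, and the import path, taken from the file list) are illustrative assumptions, not defaults shipped by the package.

    from careamics.models.lvae.lvae import LadderVAE  # path assumed from the file list above

    model = LadderVAE(
        input_shape=64,               # square 64x64 patches (assumed)
        output_channels=2,
        multiscale_count=1,           # no lateral contextualization
        z_dims=[128, 128, 128, 128],  # one latent size per hierarchy level
        encoder_n_filters=64,
        decoder_n_filters=64,
        encoder_dropout=0.1,
        decoder_dropout=0.1,
        nonlinearity="ELU",           # any name accepted by get_activation
        predict_logvar=None,          # the constructor checks `predict_logvar is not None`
        enable_noise_model=False,
        analytical_kl=False,
    )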
+     ### SET OF METHODS TO CREATE MODEL BLOCKS
+     def create_first_bottom_up(
+         self,
+         init_stride: int,
+         num_res_blocks: int = 1,
+     ) -> nn.Sequential:
+         """
+         This method creates the first bottom-up block of the Encoder.
+         Its role is to perform a first image compression step.
+         It is composed of a sequence of nn.Conv2d + non-linearity +
+         BottomUpDeterministicResBlock (1 or more, default is 1).
+
+         Parameters
+         ----------
+         init_stride: int
+             The stride used by the initial Conv2d block.
+         num_res_blocks: int, optional
+             The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
+         """
+         nonlin = get_activation(self.nonlin)
+         modules = [
+             nn.Conv2d(
+                 in_channels=self.color_ch,
+                 out_channels=self.encoder_n_filters,
+                 kernel_size=self.encoder_res_block_kernel,
+                 padding=(
+                     0
+                     if self.encoder_res_block_skip_padding
+                     else self.encoder_res_block_kernel // 2
+                 ),
+                 stride=init_stride,
+             ),
+             nonlin,
+         ]
+
+         for _ in range(num_res_blocks):
+             modules.append(
+                 BottomUpDeterministicResBlock(
+                     c_in=self.encoder_n_filters,
+                     c_out=self.encoder_n_filters,
+                     nonlin=nonlin,
+                     downsample=False,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     skip_padding=self.encoder_res_block_skip_padding,
+                     res_block_kernel=self.encoder_res_block_kernel,
+                 )
+             )
+
+         return nn.Sequential(*modules)
+
+     def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
+         """
+         This method creates the stack of bottom-up layers of the Encoder
+         that are used to generate the so-called `bu_values`.
+
+         NOTE:
+         If `self._multiscale_count < self.n_layers`, then LC is done only in the first
+         `self._multiscale_count` bottom-up layers (starting from the bottom).
+
+         Parameters
+         ----------
+         lowres_separate_branch: bool
+             Whether the residual block(s) used for encoding the low-res input are shared (`False`) or
+             not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
+         """
+         multiscale_lowres_size_factor = 1
+         nonlin = get_activation(self.nonlin)
+
+         bottom_up_layers = nn.ModuleList([])
+         for i in range(self.n_layers):
+             # Whether this is the top layer
+             is_top = i == self.n_layers - 1
+
+             # LC is applied only to the first (_multiscale_count - 1) bottom-up layers
+             layer_enable_multiscale = (
+                 self.enable_multiscale and self._multiscale_count > i + 1
+             )
+
+             # This factor determines the factor by which the low-resolution tensor is larger.
+             # N.B. Only used if layer_enable_multiscale == True, so we update it only in that case
+             multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
+
+             output_expected_shape = (
+                 (self.img_shape[0] // 2 ** (i + 1), self.img_shape[1] // 2 ** (i + 1))
+                 if self._multiscale_count > 1
+                 else None
+             )
+
+             # Add bottom-up deterministic layer at level i.
+             # It's a sequence of residual blocks (BottomUpDeterministicResBlock),
+             # possibly with downsampling between them.
+             bottom_up_layers.append(
+                 BottomUpLayer(
+                     n_res_blocks=self.encoder_blocks_per_layer,
+                     n_filters=self.encoder_n_filters,
+                     downsampling_steps=self.downsample[i],
+                     nonlin=nonlin,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.encoder_res_block_kernel,
+                     res_block_skip_padding=self.encoder_res_block_skip_padding,
+                     gated=self.gated,
+                     lowres_separate_branch=lowres_separate_branch,
+                     enable_multiscale=self.enable_multiscale,  # shouldn't the arg be `layer_enable_multiscale` here?
+                     multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
+                     multiscale_lowres_size_factor=multiscale_lowres_size_factor,
+                     decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                     output_expected_shape=output_expected_shape,
+                 )
+             )
+
+         return bottom_up_layers
+
+     def create_top_down_layers(self) -> nn.ModuleList:
+         """
+         This method creates the stack of top-down layers of the Decoder.
+         In these layers, the `bu_values` from the Encoder are merged with the `p_params` from the previous layer
+         of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
+         with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
+         compute the `p_params` for the layer below.
+
+         NOTE 1:
+         The algorithm for generative inference approximately works as follows:
+             - p_params = output of top-down layer above
+             - bu = inferred bottom-up value at this layer
+             - q_params = merge(bu, p_params)
+             - z = stochastic_layer(q_params)
+             - (optional) get and merge skip connection from prev top-down layer
+             - top-down deterministic ResNet
+
+         NOTE 2:
+         When doing unconditional generation, bu_value is not available. Hence, the
+         merge layer is not used, and z is sampled directly from p_params.
+         """
+         top_down_layers = nn.ModuleList([])
+         nonlin = get_activation(self.nonlin)
+         # NOTE: top-down layers are created starting from the bottom-most
+         for i in range(self.n_layers):
+             # Check if this is the top layer
+             is_top = i == self.n_layers - 1
+
+             if self._enable_topdown_normalize_factor:
+                 normalize_latent_factor = (
+                     1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
+                 )
+             else:
+                 normalize_latent_factor = 1.0
+
+             top_down_layers.append(
+                 TopDownLayer(
+                     z_dim=self.z_dims[i],
+                     n_res_blocks=self.decoder_blocks_per_layer,
+                     n_filters=self.decoder_n_filters,
+                     is_top_layer=is_top,
+                     downsampling_steps=self.downsample[i],
+                     nonlin=nonlin,
+                     merge_type=self.merge_type,
+                     batchnorm=self.topdown_batchnorm,
+                     dropout=self.decoder_dropout,
+                     stochastic_skip=self.stochastic_skip,
+                     learn_top_prior=self.learn_top_prior,
+                     top_prior_param_shape=self.get_top_prior_param_shape(),
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.decoder_res_block_kernel,
+                     res_block_skip_padding=self.decoder_res_block_skip_padding,
+                     gated=self.gated,
+                     analytical_kl=self.analytical_kl,
+                     restricted_kl=self._restricted_kl,
+                     vanilla_latent_hw=self.get_latent_spatial_size(i),
+                     # in no_padding_mode, what gets passed from the encoder are not multiples of 2, so the merging operation does not work natively.
+                     bottomup_no_padding_mode=self.encoder_no_padding_mode,
+                     topdown_no_padding_mode=self.decoder_no_padding_mode,
+                     retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                     non_stochastic_version=self.non_stochastic_version,
+                     input_image_shape=self.img_shape,
+                     normalize_latent_factor=normalize_latent_factor,
+                     conv2d_bias=self.topdown_conv2d_bias,
+                     stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
+                 )
+             )
+         return top_down_layers
+
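The NOTE 1 recursion above is easiest to see in isolation. The toy step below is a self-contained sketch of the merge-then-sample logic, not the package's TopDownLayer: the averaging merge stands in for the residual merge, and tensors are assumed to carry concatenated (mu, logvar) halves along the channel axis.

    from typing import Optional

    import torch

    def toy_top_down_step(
        p_params: torch.Tensor, bu_value: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Inference: merge bottom-up evidence with the prior parameters.
        # Unconditional generation (bu_value is None): use p_params directly.
        params = p_params if bu_value is None else 0.5 * (p_params + bu_value)
        mu, logvar = params.chunk(2, dim=1)
        # Reparameterized sample z ~ N(mu, exp(logvar))
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)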
+     def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
+         """
+         This method creates the final top-down layer of the Decoder.
+
+         Parameters
+         ----------
+         upsample: bool
+             Whether to upsample the input of the final top-down layer
+             by bilinear interpolation with `scale_factor=2`.
+         """
+         # Final top-down layer
+         modules = list()
+
+         if upsample:
+             modules.append(Interpolate(scale=2))
+
+         for i in range(self.decoder_blocks_per_layer):
+             modules.append(
+                 TopDownDeterministicResBlock(
+                     c_in=self.decoder_n_filters,
+                     c_out=self.decoder_n_filters,
+                     nonlin=get_activation(self.nonlin),
+                     batchnorm=self.topdown_batchnorm,
+                     dropout=self.decoder_dropout,
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.decoder_res_block_kernel,
+                     skip_padding=self.decoder_res_block_skip_padding,
+                     gated=self.gated,
+                     conv2d_bias=self.topdown_conv2d_bias,
+                 )
+             )
+         return nn.Sequential(*modules)
+
+     def _init_multires(
+         self, config=None
+     ) -> nn.ModuleList:  # TODO config: ml_collections.ConfigDict refactor
+         """
+         This method defines the input blocks/branches used to encode/compress the low-res lateral inputs at the
+         different hierarchical levels of the multiresolution approach (LC). The role of the input branches is
+         similar to that of the first bottom-up layer in the primary flow of the Encoder, namely to compress the
+         lateral input image to a degree that is compatible with the one of the primary flow.
+
+         NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity + BottomUpDeterministicResBlock.
+         Note that the `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+         in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc.). Moreover, it does not perform downsampling.
+
+         NOTE 2: The `_multiscale_count` attribute defines the total number of inputs to the bottom-up pass.
+         In other terms, if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
+         """
+         stride = 1 if self.no_initial_downscaling else 2
+         nonlin = get_activation(self.nonlin)
+         if self._multiscale_count is None:
+             self._multiscale_count = 1
+
+         msg = "Multiscale count ({}) should not exceed the number of bottom-up layers ({}) by more than 1"
+         msg = msg.format(self._multiscale_count, self.n_layers)
+         assert (
+             self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
+         ), msg
+
+         msg = "If multiscale is enabled, then we only work with monochrome images."
+         assert self._multiscale_count == 1 or self.color_ch == 1, msg
+
+         lowres_first_bottom_ups = []
+         for _ in range(1, self._multiscale_count):
+             first_bottom_up = nn.Sequential(
+                 nn.Conv2d(
+                     in_channels=self.color_ch,
+                     out_channels=self.encoder_n_filters,
+                     kernel_size=5,
+                     padding=2,
+                     stride=stride,
+                 ),
+                 nonlin,
+                 BottomUpDeterministicResBlock(
+                     c_in=self.encoder_n_filters,
+                     c_out=self.encoder_n_filters,
+                     nonlin=nonlin,
+                     downsample=False,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     skip_padding=self.encoder_res_block_skip_padding,
+                 ),
+             )
+             lowres_first_bottom_ups.append(first_bottom_up)
+
+         self.lowres_first_bottom_ups = (
+             nn.ModuleList(lowres_first_bottom_ups)
+             if len(lowres_first_bottom_ups)
+             else None
+         )
+
+     ### SET OF FORWARD-LIKE METHODS
+     def bottomup_pass(self, inp: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Wrapper of _bottomup_pass().
+         """
+         return self._bottomup_pass(
+             inp,
+             self.first_bottom_up,
+             self.lowres_first_bottom_ups,
+             self.bottom_up_layers,
+         )
+
+     def _bottomup_pass(
+         self,
+         inp: torch.Tensor,
+         first_bottom_up: nn.Sequential,
+         lowres_first_bottom_ups: nn.ModuleList,
+         bottom_up_layers: nn.ModuleList,
+     ) -> List[torch.Tensor]:
+         """
+         This method defines the forward pass through the LVAE Encoder, the so-called
+         Bottom-Up pass.
+
+         Parameters
+         ----------
+         inp: torch.Tensor
+             The input tensor to the bottom-up pass of shape (B, 1+n_LC, H, W), where n_LC
+             is the number of lateral low-res inputs used in the LC approach.
+             In particular, the first channel corresponds to the input patch, while the
+             remaining ones are associated with the lateral low-res inputs.
+         first_bottom_up: nn.Sequential
+             The module defining the first bottom-up layer of the Encoder.
+         lowres_first_bottom_ups: nn.ModuleList
+             The list of modules defining Lateral Contextualization.
+         bottom_up_layers: nn.ModuleList
+             The list of modules defining the stack of bottom-up layers of the Encoder.
+         """
+         if self._multiscale_count > 1:
+             x = first_bottom_up(inp[:, :1])
+         else:
+             x = first_bottom_up(inp)
+
+         # Loop from bottom to top layer, store all deterministic nodes we
+         # need for the top-down pass in the bu_values list
+         bu_values = []
+         for i in range(self.n_layers):
+             lowres_x = None
+             if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
+                 lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
+
+             x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
+             bu_values.append(bu_value)
+
+         return bu_values
+
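The (B, 1+n_LC, H, W) input layout described in the docstring can be illustrated with a throwaway tensor; channel 0 feeds `first_bottom_up`, and channel i+1 feeds `lowres_first_bottom_ups[i]` at layer i. All sizes here are arbitrary assumptions for illustration.

    import torch

    B, n_LC, H, W = 4, 2, 64, 64
    inp = torch.randn(B, 1 + n_LC, H, W)
    primary = inp[:, :1]     # the input patch
    lateral_0 = inp[:, 1:2]  # low-res input consumed at layer i = 0
    lateral_1 = inp[:, 2:3]  # low-res input consumed at layer i = 1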
+     def topdown_pass(
+         self,
+         bu_values: torch.Tensor = None,
+         n_img_prior: torch.Tensor = None,
+         mode_layers: Iterable[int] = None,
+         constant_layers: Iterable[int] = None,
+         forced_latent: List[torch.Tensor] = None,
+         top_down_layers: nn.ModuleList = None,
+         final_top_down_layer: nn.Sequential = None,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         This method defines the forward pass through the LVAE Decoder, the so-called
+         Top-Down pass.
+
+         Parameters
+         ----------
+         bu_values: torch.Tensor, optional
+             Output of the bottom-up pass. It will have values from multiple layers of the ladder.
+         n_img_prior: optional
+             When `bu_values` is `None`, `n_img_prior` indicates the number of images to generate
+             from the prior (so the bottom-up pass is not used at all here).
+         mode_layers: Iterable[int], optional
+             A sequence of indices associated with the layers in which sampling is disabled and
+             the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
+         constant_layers: Iterable[int], optional
+             A sequence of indices associated with the layers in which a single instance's z is
+             copied over the entire batch (the bottom-up path is not used, so only the prior is used here).
+             Set to `None` to avoid this behaviour.
+         forced_latent: List[torch.Tensor], optional
+             A list of tensors that are used as fixed latent variables (hence, sampling doesn't take
+             place in this case).
+         top_down_layers: nn.ModuleList, optional
+             A list of top-down layers to use in the top-down pass. If `None`, the method uses the
+             default layers defined in the constructor.
+         final_top_down_layer: nn.Sequential, optional
+             The last top-down layer of the top-down pass. If `None`, the method uses the default
+             layers defined in the constructor.
+         """
+         if top_down_layers is None:
+             top_down_layers = self.top_down_layers
+         if final_top_down_layer is None:
+             final_top_down_layer = self.final_top_down
+
+         # Default: no layer is sampled from the distribution's mode
+         if mode_layers is None:
+             mode_layers = []
+         if constant_layers is None:
+             constant_layers = []
+         prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0
+
+         # If the bottom-up inference values are not given, don't do
+         # inference, sample from prior instead
+         inference_mode = bu_values is not None
+
+         # Check consistency of arguments
+         if inference_mode != (n_img_prior is None):
+             msg = (
+                 "Number of images for top-down generation has to be given "
+                 "if and only if we're not doing inference"
+             )
+             raise RuntimeError(msg)
+         if (
+             inference_mode
+             and prior_experiment
+             and (self.non_stochastic_version is False)
+         ):
+             msg = (
+                 "Prior experiments (e.g. sampling from mode) are not"
+                 " compatible with inference mode"
+             )
+             raise RuntimeError(msg)
+
+         # Sampled latent variables at each layer
+         z = [None] * self.n_layers
+
+         # KL divergence of each layer
+         kl = [None] * self.n_layers
+         # KL divergence restricted, only for the LC-enabled setup of denoiSplit.
+         kl_restricted = [None] * self.n_layers
+
+         # mean from which z is sampled.
+         q_mu = [None] * self.n_layers
+         # log(var) from which z is sampled.
+         q_lv = [None] * self.n_layers
+
+         # Spatial map of KL divergence for each layer
+         kl_spatial = [None] * self.n_layers
+
+         debug_qvar_max = [None] * self.n_layers
+
+         kl_channelwise = [None] * self.n_layers
+
+         if forced_latent is None:
+             forced_latent = [None] * self.n_layers
+
+         # log p(z) where z is the sample in the topdown pass
+         # logprob_p = 0.
+
+         # Top-down inference/generation loop
+         out = out_pre_residual = None
+         for i in reversed(range(self.n_layers)):
+
+             # If available, get deterministic node from bottom-up inference
+             try:
+                 bu_value = bu_values[i]
+             except TypeError:
+                 bu_value = None
+
+             # Whether the current layer should be sampled from the mode
+             use_mode = i in mode_layers
+             constant_out = i in constant_layers
+
+             # Input for skip connection
+             skip_input = out  # TODO or n? or both?
+
+             # Full top-down layer, including sampling and deterministic part
+             out, out_pre_residual, aux = top_down_layers[i](
+                 input_=out,
+                 skip_connection_input=skip_input,
+                 inference_mode=inference_mode,
+                 bu_value=bu_value,
+                 n_img_prior=n_img_prior,
+                 use_mode=use_mode,
+                 force_constant_output=constant_out,
+                 forced_latent=forced_latent[i],
+                 mode_pred=self.mode_pred,
+                 var_clip_max=self._var_clip_max,
+             )
+
+             # Save useful variables
+             z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
+             kl[i] = aux["kl_samplewise"]  # (batch, )
+             kl_restricted[i] = aux["kl_samplewise_restricted"]
+             kl_spatial[i] = aux["kl_spatial"]  # (batch, h, w)
+             q_mu[i] = aux["q_mu"]
+             q_lv[i] = aux["q_lv"]
+
+             kl_channelwise[i] = aux["kl_channelwise"]
+             debug_qvar_max[i] = aux["qvar_max"]
+             # if self.mode_pred is False:
+             #     logprob_p += aux['logprob_p'].mean()  # mean over batch
+             # else:
+             #     logprob_p = None
+
+         # Final top-down layer
+         out = final_top_down_layer(out)
+
+         # Store useful variables in a dict to return them
+         data = {
+             "z": z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
+             "kl": kl,  # list of tensors with shape (batch, )
+             "kl_restricted": kl_restricted,  # list of tensors with shape (batch, )
+             "kl_spatial": kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
+             "kl_channelwise": kl_channelwise,  # list of tensors with shape (batch, ch[i])
+             # 'logprob_p': logprob_p,  # scalar, mean over batch
+             "q_mu": q_mu,
+             "q_lv": q_lv,
+             "debug_qvar_max": debug_qvar_max,
+         }
+         return out, data
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input tensor of shape (B, C, H, W).
+         """
+         img_size = x.size()[2:]
+
+         # Pad input to the closest larger size compatible with the downscaling
+         x_pad = self.pad_input(x)
+
+         # Bottom-up inference: returns a list of length n_layers (bottom to top)
+         bu_values = self.bottomup_pass(x_pad)
+         for i in range(0, self.skip_bottomk_buvalues):
+             bu_values[i] = None
+
+         mode_layers = range(self.n_layers) if self.non_stochastic_version else None
+
+         # Top-down inference/generation
+         out, td_data = self.topdown_pass(bu_values, mode_layers=mode_layers)
+
+         if out.shape[-1] > img_size[-1]:
+             # Restore original image size
+             out = crop_img_tensor(out, img_size)
+
+         out = self.output_layer(out)
+         if self._tethered_to_input:
+             assert out.shape[1] == 1
+             ch2 = self.get_other_channel(out, x_pad)
+             out = torch.cat([out, ch2], dim=1)
+
+         return out, td_data
+
+     ### SET OF UTILS METHODS
+     # def sample_prior(
+     #     self,
+     #     n_imgs,
+     #     mode_layers=None,
+     #     constant_layers=None
+     # ):
+
+     #     # Generate from prior
+     #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
+     #     out = crop_img_tensor(out, self.img_shape)
+
+     #     # Log likelihood and other info (per data point)
+     #     _, likelihood_data = self.likelihood(out, None)
+
+     #     return likelihood_data['sample']
+
+     # ### ???
+     # def sample_from_q(self, x, masks=None):
+     #     """
+     #     This method performs the bottomup_pass() and samples from the
+     #     obtained distribution.
+     #     """
+     #     img_size = x.size()[2:]
+
+     #     # Pad input to make everything easier with conv strides
+     #     x_pad = self.pad_input(x)
+
+     #     # Bottom-up inference: return list of length n_layers (bottom to top)
+     #     bu_values = self.bottomup_pass(x_pad)
+     #     return self._sample_from_q(bu_values, masks=masks)
+     # ### ???
+
+     # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
+     #     if top_down_layers is None:
+     #         top_down_layers = self.top_down_layers
+     #     if final_top_down_layer is None:
+     #         final_top_down_layer = self.final_top_down
+     #     if masks is None:
+     #         masks = [None] * len(bu_values)
+
+     #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
+     #     assert self.n_layers == 1, msg
+     #     samples = []
+     #     for i in reversed(range(self.n_layers)):
+     #         bu_value = bu_values[i]
+
+     #         # Note that the first argument can be set to None since we are just dealing with one level
+     #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
+     #         samples.append(sample)
+
+     #     return samples
+
+     # def reset_for_different_output_size(self, output_size):
+     #     for i in range(self.n_layers):
+     #         sz = output_size // 2**(1 + i)
+     #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
+     #         self.top_down_layers[i].latent_shape = (output_size, output_size)
+
+     def pad_input(self, x):
+         """
+         Pads the input x so that its spatial dimensions are multiples of the
+         overall downscale factor (a power of 2).
+         :param x: input tensor
+         :return: padded tensor
+         """
+         size = self.get_padded_size(x.size())
+         x = pad_img_tensor(x, size)
+         return x
+
+     ### SET OF GETTERS
+     def get_padded_size(self, size):
+         """
+         Returns the smallest size (H, W) greater than or equal to the actual
+         size given as input, such that H and W are divisible by the overall
+         downscale factor (a power of 2).
+         :param size: input size, tuple either (N, C, H, W) or (H, W)
+         :return: 2-tuple (H, W)
+         """
+         # Make size argument into (height, width)
+         if len(size) == 4:
+             size = size[2:]
+         if len(size) != 2:
+             msg = (
+                 "input size must be either (N, C, H, W) or (H, W), but it "
+                 f"has length {len(size)} (size={size})"
+             )
+             raise RuntimeError(msg)
+
+         if self.multiscale_decoder_retain_spatial_dims is True:
+             # In this case the network can go much deeper, so this padding is
+             # not required (as implemented here; more work would be needed to
+             # implement it correctly for this case)
+             return list(size)
+
+         # Overall downscale factor from input to top layer (power of 2)
+         dwnsc = self.overall_downscale_factor
+
+         # Output the smallest multiples of the downscale factor that are
+         # larger than the current sizes
+         padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)
+
+         return padded_size
+
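The rounding rule in `get_padded_size` simply bumps each spatial size up to the next multiple of the overall downscale factor. A quick check of the arithmetic, assuming 4 layers (so dwnsc = 2**4 = 16):

    s, dwnsc = 65, 16
    padded = ((s - 1) // dwnsc + 1) * dwnsc
    assert padded == 80  # 65 is rounded up to the next multiple of 16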
+     def get_latent_spatial_size(self, level_idx: int):
+         """
+         level_idx: 0 is the bottommost layer, the highest resolution one.
+         """
+         actual_downsampling = level_idx + 1
+         dwnsc = 2**actual_downsampling
+         sz = self.get_padded_size(self.img_shape)
+         h = sz[0] // dwnsc
+         w = sz[1] // dwnsc
+         assert h == w
+         return h
+
+     def get_top_prior_param_shape(self, n_imgs: int = 1):
+         # TODO num channels depends on random variable we're using
+
+         # Compute the total downscaling performed in the Encoder
+         if self.multiscale_decoder_retain_spatial_dims is False:
+             dwnsc = self.overall_downscale_factor
+         else:
+             # LC allows the encoder latents to keep the same (H, W) size at different levels
+             actual_downsampling = self.n_layers + 1 - self._multiscale_count
+             dwnsc = 2**actual_downsampling
+
+         sz = self.get_padded_size(self.img_shape)
+         h = sz[0] // dwnsc
+         w = sz[1] // dwnsc
+         c = self.z_dims[-1] * 2  # mu and logvar
+         top_layer_shape = (n_imgs, c, h, w)
+         return top_layer_shape
+
+     def get_other_channel(self, ch1, input):
+         assert self.data_std["target"].squeeze().shape == (2,)
+         assert self.data_mean["target"].squeeze().shape == (2,)
+         assert self.target_ch == 2
+         ch1_un = (
+             ch1[:, :1] * self.data_std["target"][:, :1]
+             + self.data_mean["target"][:, :1]
+         )
+         input_un = input * self.data_std["input"] + self.data_mean["input"]
+         ch2_un = self._tethered_ch2_scalar * (
+             input_un - ch1_un * self._tethered_ch1_scalar
+         )
+         ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
+             :, -1:
+         ]
+         return ch2
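Tying the pieces together, a hedged end-to-end sketch of running the model (reusing the illustrative `model` from the instantiation example above; `color_ch` is hardcoded to 1, and `forward` returns the prediction plus the top-down diagnostics dict):

    import torch

    x = torch.randn(8, 1, 64, 64)  # (B, C, H, W)
    out, td_data = model(x)
    print(out.shape)  # (8, target_ch, 64, 64); channels double if predict_logvar is set
    print(len(td_data["z"]), len(td_data["kl"]))  # one entry per hierarchy level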