careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
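
Beyond the new LVAE code (careamics/lvae_training/ and careamics/models/lvae/), the list above shows a reorganization: careamics/prediction/ and careamics/lightning_prediction_loop.py are removed in favor of the new careamics/prediction_utils/ package, and tiled_patching.py moves from careamics/dataset/patching/ to the new careamics/dataset/tiling/ package. Downstream imports would need updating along these lines (a sketch only; the re-exported names are inferred from the file names and the new __init__.py entries, not confirmed by this diff):

    # rc6 (modules removed in rc7):
    # from careamics.prediction.stitch_prediction import stitch_prediction
    # from careamics.dataset.patching.tiled_patching import extract_tiles

    # rc7 (assumed re-exports from the new packages; names hypothetical):
    # from careamics.prediction_utils import stitch_prediction
    # from careamics.dataset.tiling import extract_tiles

The largest single addition is careamics/models/lvae/lvae.py, shown below.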
careamics/models/lvae/lvae.py (new file)
@@ -0,0 +1,985 @@
+ """
+ Ladder VAE (LVAE) Model
+
+ The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
+ """
+
+ from typing import Dict, Iterable, List, Tuple, Union
+
+ import ml_collections
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from .layers import (
+     BottomUpDeterministicResBlock,
+     BottomUpLayer,
+     TopDownDeterministicResBlock,
+     TopDownLayer,
+ )
+ from .likelihoods import GaussianLikelihood, NoiseModelLikelihood
+ from .noise_models import get_noise_model
+ from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor
+
+
+ class LadderVAE(nn.Module):
+
+     def __init__(
+         self,
+         data_mean: Union[np.ndarray, Dict[str, torch.Tensor]],
+         data_std: Union[np.ndarray, Dict[str, torch.Tensor]],
+         config: ml_collections.ConfigDict,
+         use_uncond_mode_at: Iterable[int] = [],
+         target_ch: int = 2,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         data_mean: Union[np.ndarray, Dict[str, torch.Tensor]]
+             The mean of the data used for normalization.
+         data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
+             The standard deviation of the data used for normalization.
+         config: ml_collections.ConfigDict
+             The configuration object of the model.
+         use_uncond_mode_at: Iterable[int], optional
+             A sequence of indexes associated to the layers in which sampling is disabled
+             and the mode (mean value) is used instead. Default is `[]`.
+         target_ch: int, optional
+             The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
+             Default is `2`.
+         """
+         super().__init__()
+
+         # -------------------------------------------------------
+         # Customizable attributes
+         self.image_size = config.data.image_size
+         self._multiscale_count = config.data.multiscale_lowres_count
+         self.z_dims = config.model.z_dims
+         self.encoder_n_filters = config.model.n_filters
+         self.decoder_n_filters = config.model.n_filters
+         self.encoder_dropout = config.model.dropout
+         self.decoder_dropout = config.model.dropout
+         self.nonlin = config.model.nonlin
+         self.predict_logvar = config.model.predict_logvar
+         self.enable_noise_model = config.model.enable_noise_model
+         self.noise_model_ch1_fpath = config.model.noise_model_ch1_fpath
+         self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath
+         self.analytical_kl = config.model.analytical_kl
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Model attributes -> Hardcoded
+         self.model_type = ModelType.LadderVae
+         self.encoder_blocks_per_layer = 1
+         self.decoder_blocks_per_layer = 1
+         self.bottomup_batchnorm = True
+         self.topdown_batchnorm = True
+         self.topdown_conv2d_bias = True
+         self.gated = True
+         self.encoder_res_block_kernel = 3
+         self.decoder_res_block_kernel = 3
+         self.encoder_res_block_skip_padding = False
+         self.decoder_res_block_skip_padding = False
+         self.merge_type = "residual"
+         self.no_initial_downscaling = True
+         self.skip_bottomk_buvalues = 0
+         self.non_stochastic_version = False
+         self.stochastic_skip = True
+         self.learn_top_prior = True
+         self.res_block_type = "bacdbacd"
+         self.mode_pred = False
+         self.logvar_lowerbound = -5
+         self._var_clip_max = 20
+         self._stochastic_use_naive_exponential = False
+         self._enable_topdown_normalize_factor = True
+
+         # Noise model attributes -> Hardcoded
+         self.noise_model_type = "gmm"
+         self.denoise_channel = (
+             "input"  # 4 values for denoise_channel: {'Ch1', 'Ch2', 'input', 'all'}
+         )
+         self.noise_model_learnable = False
+
+         # Attributes that handle LC -> Hardcoded
+         self.enable_multiscale = (
+             self._multiscale_count is not None and self._multiscale_count > 1
+         )
+         self.multiscale_retain_spatial_dims = True
+         self.multiscale_lowres_separate_branch = False
+         self.multiscale_decoder_retain_spatial_dims = (
+             self.multiscale_retain_spatial_dims and self.enable_multiscale
+         )
+
+         # Derived attributes
+         self.n_layers = len(self.z_dims)
+         self.encoder_no_padding_mode = (
+             self.encoder_res_block_skip_padding is True
+             and self.encoder_res_block_kernel > 1
+         )
+         self.decoder_no_padding_mode = (
+             self.decoder_res_block_skip_padding is True
+             and self.decoder_res_block_kernel > 1
+         )
+
+         # Others...
+         self._tethered_to_input = False
+         self._tethered_ch1_scalar = self._tethered_ch2_scalar = None
+         if self._tethered_to_input:
+             target_ch = 1
+             requires_grad = False
+             self._tethered_ch1_scalar = nn.Parameter(
+                 torch.ones(1) * 0.5, requires_grad=requires_grad
+             )
+             self._tethered_ch2_scalar = nn.Parameter(
+                 torch.ones(1) * 2.0, requires_grad=requires_grad
+             )
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Data attributes
+         self.color_ch = 1
+         self.img_shape = (self.image_size, self.image_size)
+         self.normalized_input = True
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Loss attributes
+         self._restricted_kl = False  # HC
+         # enabling reconstruction loss on mixed input
+         self.mixed_rec_w = 0
+         self.nbr_consistency_w = 0
+
+         # Setting the loss_type
+         self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # # Training attributes
+         # # can be used to tile the validation predictions
+         # self._val_idx_manager = val_idx_manager
+         # self._val_frame_creator = None
+         # # initialize the learning rate scheduler params.
+         # self.lr_scheduler_monitor = self.lr_scheduler_mode = None
+         # self._init_lr_scheduler_params(config)
+         # self._global_step = 0
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         # Attributes from constructor arguments
+         self.target_ch = target_ch
+         self.use_uncond_mode_at = use_uncond_mode_at
+
+         # Data mean and std used for normalization
+         if isinstance(data_mean, np.ndarray):
+             self.data_mean = torch.Tensor(data_mean)
+             self.data_std = torch.Tensor(data_std)
+         elif isinstance(data_mean, dict):
+             for k in data_mean.keys():
+                 data_mean[k] = (
+                     torch.Tensor(data_mean[k])
+                     if not isinstance(data_mean[k], dict)
+                     else data_mean[k]
+                 )
+                 data_std[k] = (
+                     torch.Tensor(data_std[k])
+                     if not isinstance(data_std[k], dict)
+                     else data_std[k]
+                 )
+             self.data_mean = data_mean
+             self.data_std = data_std
+         else:
+             raise NotImplementedError(
+                 "data_mean and data_std must be either a numpy array or a dictionary"
+             )
+
+         assert self.data_std is not None
+         assert self.data_mean is not None
+
+         # Initialize the Noise Model
+         self.likelihood_gm = self.likelihood_NM = None
+         self.noiseModel = get_noise_model(
+             enable_noise_model=self.enable_noise_model,
+             model_type=self.model_type,
+             noise_model_type=self.noise_model_type,
+             noise_model_ch1_fpath=self.noise_model_ch1_fpath,
+             noise_model_ch2_fpath=self.noise_model_ch2_fpath,
+             noise_model_learnable=self.noise_model_learnable,
+         )
+
+         if self.noiseModel is None:
+             self.likelihood_form = "gaussian"
+         else:
+             self.likelihood_form = "noise_model"
+
+         # Calculate the downsampling happening in the network
+         self.downsample = [1] * self.n_layers
+         self.overall_downscale_factor = np.power(2, sum(self.downsample))
+         if not self.no_initial_downscaling:  # by default do another downscaling
+             self.overall_downscale_factor *= 2
+
+         assert max(self.downsample) <= self.encoder_blocks_per_layer
+         assert len(self.downsample) == self.n_layers
+         # -------------------------------------------------------
+
+         # -------------------------------------------------------
+         ### CREATE MODEL BLOCKS
+         # First bottom-up layer: change num channels + downsample by factor 2
+         # unless we want to prevent this
+         stride = 1 if self.no_initial_downscaling else 2
+         self.first_bottom_up = self.create_first_bottom_up(stride)
+
+         # Input Branches for Lateral Contextualization
+         self.lowres_first_bottom_ups = None
+         self._init_multires()
+
+         # Other bottom-up layers
+         self.bottom_up_layers = self.create_bottom_up_layers(
+             self.multiscale_lowres_separate_branch
+         )
+
+         # Top-down layers
+         self.top_down_layers = self.create_top_down_layers()
+         self.final_top_down = self.create_final_topdown_layer(
+             not self.no_initial_downscaling
+         )
+
+         # Likelihood module
+         self.likelihood = self.create_likelihood_module()
+
+         # Output layer --> Project to target_ch many channels
+         logvar_ch_needed = self.predict_logvar is not None
+         self.output_layer = self.parameter_net = nn.Conv2d(
+             self.decoder_n_filters,
+             self.target_ch * (1 + logvar_ch_needed),
+             kernel_size=3,
+             padding=1,
+             bias=self.topdown_conv2d_bias,
+         )
+
+         # # gradient norms. updated while training. this is also logged.
+         # self.grad_norm_bottom_up = 0.0
+         # self.grad_norm_top_down = 0.0
+         # PSNR computation on validation.
+         # self.label1_psnr = RunningPSNR()
+         # self.label2_psnr = RunningPSNR()
+
+         # msg = f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
+         # msg += f' TargetCh: {self.target_ch}'
+         # print(msg)
+
+     ### SET OF METHODS TO CREATE MODEL BLOCKS
+     def create_first_bottom_up(
+         self,
+         init_stride: int,
+         num_res_blocks: int = 1,
+     ) -> nn.Sequential:
+         """
+         This method creates the first bottom-up block of the Encoder.
+         Its role is to perform a first image compression step.
+         It is composed of a sequence of nn.Conv2d + non-linearity +
+         BottomUpDeterministicResBlock (1 or more, default is 1).
+
+         Parameters
+         ----------
+         init_stride: int
+             The stride used by the initial Conv2d block.
+         num_res_blocks: int, optional
+             The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
+         """
+         nonlin = self.get_nonlin()
+         modules = [
+             nn.Conv2d(
+                 in_channels=self.color_ch,
+                 out_channels=self.encoder_n_filters,
+                 kernel_size=self.encoder_res_block_kernel,
+                 padding=(
+                     0
+                     if self.encoder_res_block_skip_padding
+                     else self.encoder_res_block_kernel // 2
+                 ),
+                 stride=init_stride,
+             ),
+             nonlin(),
+         ]
+
+         for _ in range(num_res_blocks):
+             modules.append(
+                 BottomUpDeterministicResBlock(
+                     c_in=self.encoder_n_filters,
+                     c_out=self.encoder_n_filters,
+                     nonlin=nonlin,
+                     downsample=False,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     skip_padding=self.encoder_res_block_skip_padding,
+                     res_block_kernel=self.encoder_res_block_kernel,
+                 )
+             )
+
+         return nn.Sequential(*modules)
+
+     def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
+         """
+         This method creates the stack of bottom-up layers of the Encoder
+         that are used to generate the so-called `bu_values`.
+
+         NOTE:
+         If `self._multiscale_count < self.n_layers`, then LC is done only in the first
+         `self._multiscale_count` bottom-up layers (starting from the bottom).
+
+         Parameters
+         ----------
+         lowres_separate_branch: bool
+             Whether the residual block(s) used for encoding the low-res input are shared (`False`) or
+             not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
+         """
+         multiscale_lowres_size_factor = 1
+         nonlin = self.get_nonlin()
+
+         bottom_up_layers = nn.ModuleList([])
+         for i in range(self.n_layers):
+             # Whether this is the top layer
+             is_top = i == self.n_layers - 1
+
+             # LC is applied only to the first (_multiscale_count - 1) bottom-up layers
+             layer_enable_multiscale = (
+                 self.enable_multiscale and self._multiscale_count > i + 1
+             )
+
+             # This factor determines the factor by which the low-resolution tensor is larger
+             # N.B. Only used if layer_enable_multiscale == True, so we update it only in that case
+             multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
+
+             output_expected_shape = (
+                 (self.img_shape[0] // 2 ** (i + 1), self.img_shape[1] // 2 ** (i + 1))
+                 if self._multiscale_count > 1
+                 else None
+             )
+
+             # Add bottom-up deterministic layer at level i.
+             # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them.
+             bottom_up_layers.append(
+                 BottomUpLayer(
+                     n_res_blocks=self.encoder_blocks_per_layer,
+                     n_filters=self.encoder_n_filters,
+                     downsampling_steps=self.downsample[i],
+                     nonlin=nonlin,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.encoder_res_block_kernel,
+                     res_block_skip_padding=self.encoder_res_block_skip_padding,
+                     gated=self.gated,
+                     lowres_separate_branch=lowres_separate_branch,
+                     enable_multiscale=self.enable_multiscale,  # shouldn't the arg be `layer_enable_multiscale` here?
+                     multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
+                     multiscale_lowres_size_factor=multiscale_lowres_size_factor,
+                     decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                     output_expected_shape=output_expected_shape,
+                 )
+             )
+
+         return bottom_up_layers
+
+     def create_top_down_layers(self) -> nn.ModuleList:
+         """
+         This method creates the stack of top-down layers of the Decoder.
+         In these layers the `bu_values` from the Encoder are merged with the `p_params` from the previous layer
+         of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
+         with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
+         compute the `p_params` for the layer below.
+
+         NOTE 1:
+         The algorithm for generative inference approximately works as follows:
+             - p_params = output of top-down layer above
+             - bu = inferred bottom-up value at this layer
+             - q_params = merge(bu, p_params)
+             - z = stochastic_layer(q_params)
+             - (optional) get and merge skip connection from prev top-down layer
+             - top-down deterministic ResNet
+
+         NOTE 2:
+         When doing unconditional generation, bu_value is not available. Hence the
+         merge layer is not used, and z is sampled directly from p_params.
+         """
+         top_down_layers = nn.ModuleList([])
+         nonlin = self.get_nonlin()
+         # NOTE: top-down layers are created starting from the bottom-most
+         for i in range(self.n_layers):
+             # Check if this is the top layer
+             is_top = i == self.n_layers - 1
+
+             if self._enable_topdown_normalize_factor:
+                 normalize_latent_factor = (
+                     1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
+                 )
+             else:
+                 normalize_latent_factor = 1.0
+
+             top_down_layers.append(
+                 TopDownLayer(
+                     z_dim=self.z_dims[i],
+                     n_res_blocks=self.decoder_blocks_per_layer,
+                     n_filters=self.decoder_n_filters,
+                     is_top_layer=is_top,
+                     downsampling_steps=self.downsample[i],
+                     nonlin=nonlin,
+                     merge_type=self.merge_type,
+                     batchnorm=self.topdown_batchnorm,
+                     dropout=self.decoder_dropout,
+                     stochastic_skip=self.stochastic_skip,
+                     learn_top_prior=self.learn_top_prior,
+                     top_prior_param_shape=self.get_top_prior_param_shape(),
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.decoder_res_block_kernel,
+                     res_block_skip_padding=self.decoder_res_block_skip_padding,
+                     gated=self.gated,
+                     analytical_kl=self.analytical_kl,
+                     restricted_kl=self._restricted_kl,
+                     vanilla_latent_hw=self.get_latent_spatial_size(i),
+                     # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so the merging operation does not work natively.
+                     bottomup_no_padding_mode=self.encoder_no_padding_mode,
+                     topdown_no_padding_mode=self.decoder_no_padding_mode,
+                     retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
+                     non_stochastic_version=self.non_stochastic_version,
+                     input_image_shape=self.img_shape,
+                     normalize_latent_factor=normalize_latent_factor,
+                     conv2d_bias=self.topdown_conv2d_bias,
+                     stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
+                 )
+             )
+         return top_down_layers
+
+     def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
+         """
+         This method creates the final top-down layer of the Decoder.
+
+         Parameters
+         ----------
+         upsample: bool
+             Whether to upsample the input of the final top-down layer
+             by bilinear interpolation with `scale_factor=2`.
+         """
+         # Final top-down layer
+         modules = list()
+
+         if upsample:
+             modules.append(Interpolate(scale=2))
+
+         for i in range(self.decoder_blocks_per_layer):
+             modules.append(
+                 TopDownDeterministicResBlock(
+                     c_in=self.decoder_n_filters,
+                     c_out=self.decoder_n_filters,
+                     nonlin=self.get_nonlin(),
+                     batchnorm=self.topdown_batchnorm,
+                     dropout=self.decoder_dropout,
+                     res_block_type=self.res_block_type,
+                     res_block_kernel=self.decoder_res_block_kernel,
+                     skip_padding=self.decoder_res_block_skip_padding,
+                     gated=self.gated,
+                     conv2d_bias=self.topdown_conv2d_bias,
+                 )
+             )
+         return nn.Sequential(*modules)
+
+     def create_likelihood_module(self):
+         """
+         This method defines the likelihood module for the current LVAE model.
+         The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
+         """
+         self.likelihood_gm = GaussianLikelihood(
+             self.decoder_n_filters,
+             self.target_ch,
+             predict_logvar=self.predict_logvar,
+             logvar_lowerbound=self.logvar_lowerbound,
+             conv2d_bias=self.topdown_conv2d_bias,
+         )
+
+         self.likelihood_NM = None
+         if self.enable_noise_model:
+             self.likelihood_NM = NoiseModelLikelihood(
+                 self.decoder_n_filters,
+                 self.target_ch,
+                 self.data_mean,
+                 self.data_std,
+                 self.noiseModel,
+             )
+         if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
+             return self.likelihood_gm
+
+         return self.likelihood_NM
+
+     def _init_multires(self, config: ml_collections.ConfigDict = None) -> None:
+         """
+         This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
+         in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
+         in the primary flow of the Encoder, namely to compress the lateral input image to a degree that is compatible with the
+         one of the primary flow.
+
+         NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity + BottomUpDeterministicResBlock.
+         It is meaningful to observe that the `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+         in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc.). Moreover, it does not perform downsampling.
+
+         NOTE 2: The `_multiscale_count` attribute defines the total number of inputs to the bottom-up pass.
+         In other terms, if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
+         """
+         stride = 1 if self.no_initial_downscaling else 2
+         nonlin = self.get_nonlin()
+         if self._multiscale_count is None:
+             self._multiscale_count = 1
+
+         msg = "Multiscale count ({}) should not exceed the number of bottom-up layers ({}) by more than 1"
+         msg = msg.format(self._multiscale_count, self.n_layers)
+         assert (
+             self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
+         ), msg
+
+         msg = (
+             "If multiscale is enabled, then we are just working with monochrome images."
+         )
+         assert self._multiscale_count == 1 or self.color_ch == 1, msg
+
+         lowres_first_bottom_ups = []
+         for _ in range(1, self._multiscale_count):
+             first_bottom_up = nn.Sequential(
+                 nn.Conv2d(
+                     in_channels=self.color_ch,
+                     out_channels=self.encoder_n_filters,
+                     kernel_size=5,
+                     padding=2,
+                     stride=stride,
+                 ),
+                 nonlin(),
+                 BottomUpDeterministicResBlock(
+                     c_in=self.encoder_n_filters,
+                     c_out=self.encoder_n_filters,
+                     nonlin=nonlin,
+                     downsample=False,
+                     batchnorm=self.bottomup_batchnorm,
+                     dropout=self.encoder_dropout,
+                     res_block_type=self.res_block_type,
+                     skip_padding=self.encoder_res_block_skip_padding,
+                 ),
+             )
+             lowres_first_bottom_ups.append(first_bottom_up)
+
+         self.lowres_first_bottom_ups = (
+             nn.ModuleList(lowres_first_bottom_ups)
+             if len(lowres_first_bottom_ups)
+             else None
+         )
+
+     ### SET OF FORWARD-LIKE METHODS
+     def bottomup_pass(self, inp: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Wrapper of _bottomup_pass().
+         """
+         return self._bottomup_pass(
+             inp,
+             self.first_bottom_up,
+             self.lowres_first_bottom_ups,
+             self.bottom_up_layers,
+         )
+
+     def _bottomup_pass(
+         self,
+         inp: torch.Tensor,
+         first_bottom_up: nn.Sequential,
+         lowres_first_bottom_ups: nn.ModuleList,
+         bottom_up_layers: nn.ModuleList,
+     ) -> List[torch.Tensor]:
+         """
+         This method defines the forward pass through the LVAE Encoder, the so-called
+         Bottom-Up pass.
+
+         Parameters
+         ----------
+         inp: torch.Tensor
+             The input tensor to the bottom-up pass of shape (B, 1+n_LC, H, W), where n_LC
+             is the number of lateral low-res inputs used in the LC approach.
+             In particular, the first channel corresponds to the input patch, while the
+             remaining ones are associated to the lateral low-res inputs.
+         first_bottom_up: nn.Sequential
+             The module defining the first bottom-up layer of the Encoder.
+         lowres_first_bottom_ups: nn.ModuleList
+             The list of modules defining Lateral Contextualization.
+         bottom_up_layers: nn.ModuleList
+             The list of modules defining the stack of bottom-up layers of the Encoder.
+         """
+         if self._multiscale_count > 1:
+             x = first_bottom_up(inp[:, :1])
+         else:
+             x = first_bottom_up(inp)
+
+         # Loop from bottom to top layer, store all deterministic nodes we
+         # need for the top-down pass in bu_values list
+         bu_values = []
+         for i in range(self.n_layers):
+             lowres_x = None
+             if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
+                 lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
+
+             x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
+             bu_values.append(bu_value)
+
+         return bu_values
+
+     def topdown_pass(
+         self,
+         bu_values: torch.Tensor = None,
+         n_img_prior: torch.Tensor = None,
+         mode_layers: Iterable[int] = None,
+         constant_layers: Iterable[int] = None,
+         forced_latent: List[torch.Tensor] = None,
+         top_down_layers: nn.ModuleList = None,
+         final_top_down_layer: nn.Sequential = None,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         This method defines the forward pass through the LVAE Decoder, the so-called
+         Top-Down pass.
+
+         Parameters
+         ----------
+         bu_values: torch.Tensor, optional
+             Output of the bottom-up pass. It will have values from multiple layers of the ladder.
+         n_img_prior: optional
+             When `bu_values` is `None`, `n_img_prior` indicates the number of images to generate
+             from the prior (so the bottom-up pass is not used at all here).
+         mode_layers: Iterable[int], optional
+             A sequence of indexes associated to the layers in which sampling is disabled and
+             the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
+         constant_layers: Iterable[int], optional
+             A sequence of indexes associated to the layers in which a single instance's z is
+             copied over the entire batch (bottom-up path is not used, so only the prior is used here).
+             Set to `None` to avoid this behaviour.
+         forced_latent: List[torch.Tensor], optional
+             A list of tensors that are used as fixed latent variables (hence, sampling doesn't take
+             place in this case).
+         top_down_layers: nn.ModuleList, optional
+             A list of top-down layers to use in the top-down pass. If `None`, the method uses the
+             default layers defined in the constructor.
+         final_top_down_layer: nn.Sequential, optional
+             The last top-down layer of the top-down pass. If `None`, the method uses the default
+             layers defined in the constructor.
+         """
+         if top_down_layers is None:
+             top_down_layers = self.top_down_layers
+         if final_top_down_layer is None:
+             final_top_down_layer = self.final_top_down
+
+         # Default: no layer is sampled from the distribution's mode
+         if mode_layers is None:
+             mode_layers = []
+         if constant_layers is None:
+             constant_layers = []
+         prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0
+
+         # If the bottom-up inference values are not given, don't do
+         # inference, sample from prior instead
+         inference_mode = bu_values is not None
+
+         # Check consistency of arguments
+         if inference_mode != (n_img_prior is None):
+             msg = (
+                 "Number of images for top-down generation has to be given "
+                 "if and only if we're not doing inference"
+             )
+             raise RuntimeError(msg)
+         if (
+             inference_mode
+             and prior_experiment
+             and (self.non_stochastic_version is False)
+         ):
+             msg = (
+                 "Prior experiments (e.g. sampling from mode) are not"
+                 " compatible with inference mode"
+             )
+             raise RuntimeError(msg)
+
+         # Sampled latent variables at each layer
+         z = [None] * self.n_layers
+
+         # KL divergence of each layer
+         kl = [None] * self.n_layers
+         # Restricted KL divergence, only for the LC-enabled setup of denoiSplit.
+         kl_restricted = [None] * self.n_layers
+
+         # mean from which z is sampled.
+         q_mu = [None] * self.n_layers
+         # log(var) from which z is sampled.
+         q_lv = [None] * self.n_layers
+
+         # Spatial map of KL divergence for each layer
+         kl_spatial = [None] * self.n_layers
+
+         debug_qvar_max = [None] * self.n_layers
+
+         kl_channelwise = [None] * self.n_layers
+
+         if forced_latent is None:
+             forced_latent = [None] * self.n_layers
+
+         # log p(z) where z is the sample in the topdown pass
+         # logprob_p = 0.
+
+         # Top-down inference/generation loop
+         out = out_pre_residual = None
+         for i in reversed(range(self.n_layers)):
+
+             # If available, get deterministic node from bottom-up inference
+             try:
+                 bu_value = bu_values[i]
+             except TypeError:
+                 bu_value = None
+
+             # Whether the current layer should be sampled from the mode
+             use_mode = i in mode_layers
+             constant_out = i in constant_layers
+             use_uncond_mode = i in self.use_uncond_mode_at
+
+             # Input for skip connection
+             skip_input = out  # TODO or n? or both?
+
+             # Full top-down layer, including sampling and deterministic part
+             out, out_pre_residual, aux = top_down_layers[i](
+                 input_=out,
+                 skip_connection_input=skip_input,
+                 inference_mode=inference_mode,
+                 bu_value=bu_value,
+                 n_img_prior=n_img_prior,
+                 use_mode=use_mode,
+                 force_constant_output=constant_out,
+                 forced_latent=forced_latent[i],
+                 mode_pred=self.mode_pred,
+                 use_uncond_mode=use_uncond_mode,
+                 var_clip_max=self._var_clip_max,
+             )
+
+             # Save useful variables
+             z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
+             kl[i] = aux["kl_samplewise"]  # (batch, )
+             kl_restricted[i] = aux["kl_samplewise_restricted"]
+             kl_spatial[i] = aux["kl_spatial"]  # (batch, h, w)
+             q_mu[i] = aux["q_mu"]
+             q_lv[i] = aux["q_lv"]
+
+             kl_channelwise[i] = aux["kl_channelwise"]
+             debug_qvar_max[i] = aux["qvar_max"]
+             # if self.mode_pred is False:
+             #     logprob_p += aux['logprob_p'].mean()  # mean over batch
+             # else:
+             #     logprob_p = None
+
+         # Final top-down layer
+         out = final_top_down_layer(out)
+
+         # Store useful variables in a dict to return them
+         data = {
+             "z": z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
+             "kl": kl,  # list of tensors with shape (batch, )
+             "kl_restricted": kl_restricted,  # list of tensors with shape (batch, )
+             "kl_spatial": kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
+             "kl_channelwise": kl_channelwise,  # list of tensors with shape (batch, ch[i])
+             # 'logprob_p': logprob_p,  # scalar, mean over batch
+             "q_mu": q_mu,
+             "q_lv": q_lv,
+             "debug_qvar_max": debug_qvar_max,
+         }
+         return out, data
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input tensor of shape (B, C, H, W).
+         """
+         img_size = x.size()[2:]
+
+         # Pad input so that H and W are multiples of the overall downscale factor
+         x_pad = self.pad_input(x)
+
+         # Bottom-up inference: return list of length n_layers (bottom to top)
+         bu_values = self.bottomup_pass(x_pad)
+         for i in range(0, self.skip_bottomk_buvalues):
+             bu_values[i] = None
+
+         mode_layers = range(self.n_layers) if self.non_stochastic_version else None
+
+         # Top-down inference/generation
+         out, td_data = self.topdown_pass(bu_values, mode_layers=mode_layers)
+
+         if out.shape[-1] > img_size[-1]:
+             # Restore original image size
+             out = crop_img_tensor(out, img_size)
+
+         out = self.output_layer(out)
+         if self._tethered_to_input:
+             assert out.shape[1] == 1
+             ch2 = self.get_other_channel(out, x_pad)
+             out = torch.cat([out, ch2], dim=1)
+
+         return out, td_data
+
+     ### SET OF UTILS METHODS
+     # def sample_prior(
+     #     self,
+     #     n_imgs,
+     #     mode_layers=None,
+     #     constant_layers=None
+     # ):
+
+     #     # Generate from prior
+     #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
+     #     out = crop_img_tensor(out, self.img_shape)
+
+     #     # Log likelihood and other info (per data point)
+     #     _, likelihood_data = self.likelihood(out, None)
+
+     #     return likelihood_data['sample']
+
+     # ### ???
+     # def sample_from_q(self, x, masks=None):
+     #     """
+     #     This method performs the bottomup_pass() and samples from the
+     #     obtained distribution.
+     #     """
+     #     img_size = x.size()[2:]
+
+     #     # Pad input to make everything easier with conv strides
+     #     x_pad = self.pad_input(x)
+
+     #     # Bottom-up inference: return list of length n_layers (bottom to top)
+     #     bu_values = self.bottomup_pass(x_pad)
+     #     return self._sample_from_q(bu_values, masks=masks)
+     # ### ???
+
+     # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
+     #     if top_down_layers is None:
+     #         top_down_layers = self.top_down_layers
+     #     if final_top_down_layer is None:
+     #         final_top_down_layer = self.final_top_down
+     #     if masks is None:
+     #         masks = [None] * len(bu_values)
+
+     #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
+     #     assert self.n_layers == 1, msg
+     #     samples = []
+     #     for i in reversed(range(self.n_layers)):
+     #         bu_value = bu_values[i]
+
+     #         # Note that the first argument can be set to None since we are just dealing with one level
+     #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
+     #         samples.append(sample)
+
+     #     return samples
+
+     # def reset_for_different_output_size(self, output_size):
+     #     for i in range(self.n_layers):
+     #         sz = output_size // 2**(1 + i)
+     #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
+     #         self.top_down_layers[i].latent_shape = (output_size, output_size)
+
+     def pad_input(self, x):
+         """
+         Pads the input x so that its spatial sizes are multiples of the
+         overall downscale factor (see get_padded_size()).
+         :param x:
+         :return: Padded tensor
+         """
+         size = self.get_padded_size(x.size())
+         x = pad_img_tensor(x, size)
+         return x
+
+     ### SET OF GETTERS
+     def get_nonlin(self):
+         nonlin = {
+             "relu": nn.ReLU,
+             "leakyrelu": nn.LeakyReLU,
+             "elu": nn.ELU,
+             "selu": nn.SELU,
+         }
+         return nonlin[self.nonlin]
+
+     def get_padded_size(self, size):
+         """
+         Returns the smallest size (H, W) larger than or equal to the actual input
+         size, such that H and W are multiples of the overall downscale factor.
+         :param size: input size, tuple either (N, C, H, W) or (H, W)
+         :return: 2-tuple (H, W)
+         """
+         # Make size argument into (height, width)
+         if len(size) == 4:
+             size = size[2:]
+         if len(size) != 2:
+             msg = (
+                 "input size must be either (N, C, H, W) or (H, W), but it "
+                 f"has length {len(size)} (size={size})"
+             )
+             raise RuntimeError(msg)
+
+         if self.multiscale_decoder_retain_spatial_dims is True:
+             # In this case we can go much deeper, so this is not required
+             # (as currently implemented; more work would be needed to support it correctly)
+             return list(size)
+
+         # Overall downscale factor from input to top layer (power of 2)
+         dwnsc = self.overall_downscale_factor
+
+         # Output the smallest multiples of dwnsc that are larger than the current sizes
+         padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)
+
+         return padded_size
+
+     def get_latent_spatial_size(self, level_idx: int):
+         """
+         level_idx: 0 is the bottommost layer, the highest resolution one.
+         """
+         actual_downsampling = level_idx + 1
+         dwnsc = 2**actual_downsampling
+         sz = self.get_padded_size(self.img_shape)
+         h = sz[0] // dwnsc
+         w = sz[1] // dwnsc
+         assert h == w
+         return h
+
+     def get_top_prior_param_shape(self, n_imgs: int = 1):
+         # TODO num channels depends on random variable we're using
+
+         # Compute the total downscaling performed in the Encoder
+         if self.multiscale_decoder_retain_spatial_dims is False:
+             dwnsc = self.overall_downscale_factor
+         else:
+             # LC allows the encoder latents to keep the same (H, W) size at different levels
+             actual_downsampling = self.n_layers + 1 - self._multiscale_count
+             dwnsc = 2**actual_downsampling
+
+         sz = self.get_padded_size(self.img_shape)
+         h = sz[0] // dwnsc
+         w = sz[1] // dwnsc
+         c = self.z_dims[-1] * 2  # mu and logvar
+         top_layer_shape = (n_imgs, c, h, w)
+         return top_layer_shape
+
+     def get_other_channel(self, ch1, input):
+         assert self.data_std["target"].squeeze().shape == (2,)
+         assert self.data_mean["target"].squeeze().shape == (2,)
+         assert self.target_ch == 2
+         ch1_un = (
+             ch1[:, :1] * self.data_std["target"][:, :1]
+             + self.data_mean["target"][:, :1]
+         )
+         input_un = input * self.data_std["input"] + self.data_mean["input"]
+         ch2_un = self._tethered_ch2_scalar * (
+             input_un - ch1_un * self._tethered_ch1_scalar
+         )
+         ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
+             :, -1:
+         ]
+         return ch2
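
For orientation, here is a minimal, hypothetical instantiation sketch inferred from the constructor above: it sets only the config keys that __init__ reads. All values are placeholder assumptions rather than careamics defaults, and it assumes get_noise_model returns None when the noise model is disabled, so that the Gaussian likelihood is used.

    import ml_collections
    import numpy as np
    import torch

    from careamics.models.lvae.lvae import LadderVAE

    # Placeholder config covering exactly the keys read in LadderVAE.__init__.
    config = ml_collections.ConfigDict()
    config.data = ml_collections.ConfigDict()
    config.data.image_size = 64                 # square patches: img_shape = (64, 64)
    config.data.multiscale_lowres_count = None  # disable lateral contextualization (LC)

    config.model = ml_collections.ConfigDict()
    config.model.z_dims = [128, 128, 128, 128]  # one entry per ladder level -> n_layers = 4
    config.model.n_filters = 64
    config.model.dropout = 0.1
    config.model.nonlin = "elu"                 # one of "relu", "leakyrelu", "elu", "selu"
    config.model.predict_logvar = None          # None -> output has exactly target_ch channels
    config.model.enable_noise_model = False     # fall back to GaussianLikelihood
    config.model.noise_model_ch1_fpath = None
    config.model.noise_model_ch2_fpath = None
    config.model.analytical_kl = False

    config.loss = ml_collections.ConfigDict()   # loss_type falls back to LossType.DenoiSplitMuSplit

    model = LadderVAE(
        data_mean=np.array([0.0]),
        data_std=np.array([1.0]),
        config=config,
        target_ch=2,
    )

    x = torch.randn(2, 1, 64, 64)  # (B, 1 + n_LC, H, W); one channel since LC is off
    out, td_data = model(x)        # out: (2, 2, 64, 64); td_data holds z, kl, q_mu, q_lv, ...

The shape arithmetic follows from the hardcoded attributes: with four layers and downsample = [1, 1, 1, 1], the overall downscale factor is 2**4 = 16 (no extra factor of 2, since no_initial_downscaling is True). For example:

    # pad_input() pads a 100x100 input up to the next multiple of 16:
    #     ((100 - 1) // 16 + 1) * 16 = 112  ->  (112, 112)
    # get_top_prior_param_shape() with img_shape = (64, 64):
    #     h = w = 64 // 16 = 4
    #     c = z_dims[-1] * 2 = 256  # mu and logvar stacked along channels
    #     -> (1, 256, 4, 4)

Note that data_mean and data_std may also be dicts with "input" and "target" entries, which the tethered-channel path in get_other_channel expects.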