careamics-0.0.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,848 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ladder VAE (LVAE) Model.
|
|
3
|
+
|
|
4
|
+
The current implementation is based on "Interpretable Unsupervised Diversity Denoising
|
|
5
|
+
and Artefact Removal, Prakash et al."
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections.abc import Iterable
|
|
9
|
+
from typing import Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
|
|
15
|
+
from ..activation import get_activation
|
|
16
|
+
from .layers import (
|
|
17
|
+
BottomUpDeterministicResBlock,
|
|
18
|
+
BottomUpLayer,
|
|
19
|
+
GateLayer,
|
|
20
|
+
TopDownDeterministicResBlock,
|
|
21
|
+
TopDownLayer,
|
|
22
|
+
)
|
|
23
|
+
from .utils import Interpolate, ModelType, crop_img_tensor
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LadderVAE(nn.Module):
|
|
27
|
+
"""
|
|
28
|
+
Constructor.
|
|
29
|
+
|
|
30
|
+
Parameters
|
|
31
|
+
----------
|
|
32
|
+
input_shape : int
|
|
33
|
+
The size of the input image.
|
|
34
|
+
output_channels : int
|
|
35
|
+
The number of output channels.
|
|
36
|
+
multiscale_count : int
|
|
37
|
+
The number of scales for multiscale processing.
|
|
38
|
+
z_dims : list[int]
|
|
39
|
+
The dimensions of the latent space for each layer.
|
|
40
|
+
encoder_n_filters : int
|
|
41
|
+
The number of filters in the encoder.
|
|
42
|
+
decoder_n_filters : int
|
|
43
|
+
The number of filters in the decoder.
|
|
44
|
+
encoder_conv_strides : list[int]
|
|
45
|
+
The strides for the conv layers encoder.
|
|
46
|
+
decoder_conv_strides : list[int]
|
|
47
|
+
The strides for the conv layers decoder.
|
|
48
|
+
encoder_dropout : float
|
|
49
|
+
The dropout rate for the encoder.
|
|
50
|
+
decoder_dropout : float
|
|
51
|
+
The dropout rate for the decoder.
|
|
52
|
+
nonlinearity : str
|
|
53
|
+
The nonlinearity function to use.
|
|
54
|
+
predict_logvar : bool
|
|
55
|
+
Whether to predict the log variance.
|
|
56
|
+
analytical_kl : bool
|
|
57
|
+
Whether to use analytical KL divergence.
|
|
58
|
+
|
|
59
|
+
Raises
|
|
60
|
+
------
|
|
61
|
+
NotImplementedError
|
|
62
|
+
If only 2D convolutions are supported.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(
    self,
    input_shape: tuple[int, ...],
    output_channels: int,
    multiscale_count: int,
    z_dims: list[int],
    encoder_n_filters: int,
    decoder_n_filters: int,
    encoder_conv_strides: list[int],
    decoder_conv_strides: list[int],
    encoder_dropout: float,
    decoder_dropout: float,
    nonlinearity: str,
    predict_logvar: bool,
    analytical_kl: bool,
):
    """
    Store the configuration and build every sub-module of the Ladder VAE.

    Parameters are described in the class docstring. Additional details that
    follow from how this constructor uses them:

    - ``input_shape`` is used as a sequence, not an int: a length of 3 switches
      the encoder to 3D mode, in which case ``input_shape[0]`` is taken as the
      depth.
    - ``len(encoder_conv_strides)`` / ``len(decoder_conv_strides)`` select the
      convolution type (``nn.Conv{len}d``) of the encoder / decoder.
    - ``len(z_dims)`` is the number of hierarchical layers.
    - The output layer produces ``output_channels`` channels, doubled whenever
      ``predict_logvar is not None``.

    Raises
    ------
    AssertionError
        If the encoder is 3D, the decoder is 2D and the depth is even, or if
        the decoder is 3D while the encoder is 2D.
    """
    super().__init__()

    # -------------------------------------------------------
    # Customizable attributes
    # Input image size: (Z, Y, X), or (Y, X) if the data is 2D.
    # TODO: we need to be careful with this since it used to be an int;
    # the tuple of shapes used to be `self.input_shape`.
    self.image_size = input_shape
    self.target_ch = output_channels
    self.encoder_conv_strides = encoder_conv_strides
    self.decoder_conv_strides = decoder_conv_strides
    self._multiscale_count = multiscale_count
    self.z_dims = z_dims
    self.encoder_n_filters = encoder_n_filters
    self.decoder_n_filters = decoder_n_filters
    self.encoder_dropout = encoder_dropout
    self.decoder_dropout = decoder_dropout
    self.nonlin = nonlinearity
    self.predict_logvar = predict_logvar
    self.analytical_kl = analytical_kl
    # -------------------------------------------------------

    # -------------------------------------------------------
    # Model attributes -> Hardcoded
    self.model_type = ModelType.LadderVae  # TODO remove !
    self.encoder_blocks_per_layer = 1
    self.decoder_blocks_per_layer = 1
    self.bottomup_batchnorm = True
    self.topdown_batchnorm = True
    self.topdown_conv2d_bias = True
    self.gated = True
    self.encoder_res_block_kernel = 3
    self.decoder_res_block_kernel = 3
    self.encoder_res_block_skip_padding = False
    self.decoder_res_block_skip_padding = False
    self.merge_type = "residual"
    self.no_initial_downscaling = True
    self.skip_bottomk_buvalues = 0
    self.stochastic_skip = True
    self.learn_top_prior = True
    self.res_block_type = "bacdbacd"  # TODO remove !
    self.mode_pred = False
    self.logvar_lowerbound = -5
    self._var_clip_max = 20
    self._stochastic_use_naive_exponential = False
    self._enable_topdown_normalize_factor = True

    # Attributes that handle Lateral Contextualization (LC) -> Hardcoded
    self.enable_multiscale = self._multiscale_count > 1
    self.multiscale_retain_spatial_dims = True
    self.multiscale_lowres_separate_branch = False
    self.multiscale_decoder_retain_spatial_dims = (
        self.multiscale_retain_spatial_dims and self.enable_multiscale
    )

    # Derived attributes
    self.n_layers = len(self.z_dims)

    # Tethering to the input is hardcoded off; the branch is kept so that it
    # can be re-enabled.
    self._tethered_to_input = False
    self._tethered_ch1_scalar = self._tethered_ch2_scalar = None
    if self._tethered_to_input:
        # NOTE(review): tethering presumably expects a single target channel,
        # but `self.target_ch` (and therefore the output layer below) is not
        # changed here -- confirm before enabling `_tethered_to_input`.
        requires_grad = False
        self._tethered_ch1_scalar = nn.Parameter(
            torch.ones(1) * 0.5, requires_grad=requires_grad
        )
        self._tethered_ch2_scalar = nn.Parameter(
            torch.ones(1) * 2.0, requires_grad=requires_grad
        )
    # -------------------------------------------------------

    # -------------------------------------------------------
    # Data attributes
    self.color_ch = 1  # TODO for now we only support 1 channel
    self.normalized_input = True
    # -------------------------------------------------------

    # -------------------------------------------------------
    # Loss attributes
    # enabling reconstruction loss on mixed input
    self.mixed_rec_w = 0
    self.nbr_consistency_w = 0

    # -------------------------------------------------------
    # 3D related stuff
    # A 3-entry image size means a 3D encoder; its first entry is the depth.
    self._mode_3D = len(self.image_size) == 3  # TODO refac
    self._model_3D_depth = self.image_size[0] if self._mode_3D else 1
    self._decoder_mode_3D = len(self.decoder_conv_strides) == 3
    if self._mode_3D and not self._decoder_mode_3D:
        assert self._model_3D_depth % 2 == 1, "3D model depth should be odd"
    assert (
        self._mode_3D is True or self._decoder_mode_3D is False
    ), "Decoder cannot be 3D when encoder is 2D"
    # With a 3D encoder and a 2D decoder, one GateLayer per hierarchy level
    # is built to "squish" the 3D features.
    self._squish3d = self._mode_3D and not self._decoder_mode_3D
    self._3D_squisher = (
        None
        if not self._squish3d
        else nn.ModuleList(
            [
                GateLayer(
                    channels=self.encoder_n_filters,
                    conv_strides=self.encoder_conv_strides,
                )
                for _ in range(len(self.z_dims))
            ]
        )
    )
    # TODO: the 3D squishing logic above is hard to follow; simplify it.

    # -------------------------------------------------------
    # Calculate the downsampling happening in the network: each layer
    # downsamples once, so the overall factor is 2**n_layers, doubled again
    # when an initial downscaling is enabled.
    self.downsample = [1] * self.n_layers
    self.overall_downscale_factor = np.power(2, sum(self.downsample))
    if not self.no_initial_downscaling:  # by default do another downscaling
        self.overall_downscale_factor *= 2

    assert max(self.downsample) <= self.encoder_blocks_per_layer
    assert len(self.downsample) == self.n_layers
    # -------------------------------------------------------

    # -------------------------------------------------------
    ### CREATE MODEL BLOCKS
    # NOTE: the construction order below is kept as-is because each module
    # draws its initial weights from the global RNG.
    # First bottom-up layer: change num channels + downsample by factor 2
    # unless we want to prevent this.
    self.encoder_conv_op = getattr(nn, f"Conv{len(self.encoder_conv_strides)}d")
    # TODO these should be defined for all layers here ?
    self.decoder_conv_op = getattr(nn, f"Conv{len(self.decoder_conv_strides)}d")
    stride = 1 if self.no_initial_downscaling else 2
    self.first_bottom_up = self.create_first_bottom_up(stride)

    # Input Branches for Lateral Contextualization
    self.lowres_first_bottom_ups = None
    self._init_multires()

    # Other bottom-up layers
    self.bottom_up_layers = self.create_bottom_up_layers(
        self.multiscale_lowres_separate_branch
    )

    # Top-down layers
    self.top_down_layers = self.create_top_down_layers()
    self.final_top_down = self.create_final_topdown_layer(
        not self.no_initial_downscaling
    )

    # Output layer --> project to `target_ch` channels, plus another
    # `target_ch` channels for the log-variance when it is predicted.
    # NOTE(review): `predict_logvar` is annotated as bool, yet `False` still
    # passes the `is not None` check below and doubles the channels; only
    # `None` disables them. Confirm which type callers actually pass.
    logvar_ch_needed = self.predict_logvar is not None
    self.output_layer = self.parameter_net = self.decoder_conv_op(
        self.decoder_n_filters,
        self.target_ch * (1 + logvar_ch_needed),
        kernel_size=3,
        padding=1,
        bias=self.topdown_conv2d_bias,
    )
|
|
265
|
+
|
|
266
|
+
### SET OF METHODS TO CREATE MODEL BLOCKS
|
|
267
|
+
def create_first_bottom_up(
    self,
    init_stride: int,
    num_res_blocks: int = 1,
) -> nn.Sequential:
    """
    Build the first bottom-up block of the encoder.

    The block performs a first compression step. It maps the ``self.color_ch``
    input channels to ``self.encoder_n_filters`` feature channels with one
    convolution (``self.encoder_conv_op``, whose dimensionality follows
    ``self.encoder_conv_strides``). The convolution is followed by the
    activation and by ``num_res_blocks`` non-downsampling
    ``BottomUpDeterministicResBlock`` modules.

    Parameters
    ----------
    init_stride : int
        Stride of the initial convolution.
    num_res_blocks : int, optional
        Number of ``BottomUpDeterministicResBlock`` modules appended after the
        activation, by default 1.

    Returns
    -------
    nn.Sequential
        The modules ``[conv, activation, res_block_1, ..., res_block_n]``.
    """
    # Design note carried over from the original authors: Z should not be
    # touched in any case.
    act = get_activation(self.nonlin)
    kernel = self.encoder_res_block_kernel
    # "Same" padding unless padding is explicitly skipped.
    pad = 0 if self.encoder_res_block_skip_padding else kernel // 2

    stem = self.encoder_conv_op(
        in_channels=self.color_ch,
        out_channels=self.encoder_n_filters,
        kernel_size=kernel,
        padding=pad,
        stride=init_stride,
    )

    # The same activation instance is reused by the stem and every res block.
    res_blocks = [
        BottomUpDeterministicResBlock(
            conv_strides=self.encoder_conv_strides,
            c_in=self.encoder_n_filters,
            c_out=self.encoder_n_filters,
            nonlin=act,
            downsample=False,
            batchnorm=self.bottomup_batchnorm,
            dropout=self.encoder_dropout,
            res_block_type=self.res_block_type,
            res_block_kernel=self.encoder_res_block_kernel,
        )
        for _ in range(num_res_blocks)
    ]

    return nn.Sequential(stem, act, *res_blocks)
|
|
318
|
+
|
|
319
|
+
def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
    """
    Build the stack of encoder bottom-up layers that produce the `bu_values`.

    One ``BottomUpLayer`` is created per hierarchy level (``self.n_layers`` of
    them), ordered from the bottom-most (index 0) to the top.

    NOTE:
    Lateral Contextualization (LC) bookkeeping: ``layer_enable_multiscale`` is
    true only for the first ``self._multiscale_count - 1`` layers and drives
    ``multiscale_lowres_size_factor``, which is doubled at each such layer.
    Every layer nevertheless receives the global ``self.enable_multiscale``
    flag.

    Parameters
    ----------
    lowres_separate_branch : bool
        Whether the residual block(s) used for encoding the low-res input are
        shared (`False`) or not (`True`) with the "same-size" residual block(s)
        in the `BottomUpLayer`'s primary flow.

    Returns
    -------
    nn.ModuleList
        The ``self.n_layers`` bottom-up layers, bottom-most first.
    """
    multiscale_lowres_size_factor = 1
    nonlin = get_activation(self.nonlin)

    bottom_up_layers = nn.ModuleList([])
    for i in range(self.n_layers):
        # LC is applied only to the first (_multiscale_count - 1) bottom-up layers
        layer_enable_multiscale = (
            self.enable_multiscale and self._multiscale_count > i + 1
        )

        # Factor by which the low-resolution tensor is larger than this
        # layer's primary flow; it only grows where LC is enabled.
        multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)

        # TODO: check correctness of this -- every axis of `self.image_size`
        # (including Z in 3D) is halved once per level.
        # Materialized as a tuple: a generator expression is single-pass, so
        # if the layer stores it and reads it more than once, every read after
        # the first would see an empty shape.
        if self._multiscale_count > 1:
            output_expected_shape = tuple(
                dim // 2 ** (i + 1) for dim in self.image_size
            )
        else:
            output_expected_shape = None

        # Add bottom-up deterministic layer at level i: a sequence of residual
        # blocks (BottomUpDeterministicResBlock), possibly with downsampling.
        bottom_up_layers.append(
            BottomUpLayer(
                n_res_blocks=self.encoder_blocks_per_layer,
                n_filters=self.encoder_n_filters,
                downsampling_steps=self.downsample[i],
                nonlin=nonlin,
                conv_strides=self.encoder_conv_strides,
                batchnorm=self.bottomup_batchnorm,
                dropout=self.encoder_dropout,
                res_block_type=self.res_block_type,
                res_block_kernel=self.encoder_res_block_kernel,
                gated=self.gated,
                lowres_separate_branch=lowres_separate_branch,
                # TODO: shouldn't the arg be `layer_enable_multiscale` here?
                enable_multiscale=self.enable_multiscale,
                multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
                multiscale_lowres_size_factor=multiscale_lowres_size_factor,
                decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
                output_expected_shape=output_expected_shape,
            )
        )

    return bottom_up_layers
|
|
383
|
+
|
|
384
|
+
def create_top_down_layers(self) -> nn.ModuleList:
    """Create the stack of top-down layers of the Decoder.

    In each of these layers the `bu_values` from the Encoder are merged with the
    `p_params` from the previous layer of the Decoder to get `q_params`. Then, a
    stochastic layer generates a sample from the latent distribution with
    parameters `q_params`. Finally, this sample is fed through a
    TopDownDeterministicResBlock to compute the `p_params` for the layer below.

    NOTE 1:
        The algorithm for generative inference approximately works as follows:
            - p_params = output of top-down layer above
            - bu = inferred bottom-up value at this layer
            - q_params = merge(bu, p_params)
            - z = stochastic_layer(q_params)
            - (optional) get and merge skip connection from prev top-down layer
            - top-down deterministic ResNet

    NOTE 2:
        When doing unconditional generation, bu_value is not available. Hence the
        merge layer is not used, and z is sampled directly from p_params.

    Returns
    -------
    nn.ModuleList
        `self.n_layers` TopDownLayer modules. Index 0 is the bottom-most layer and
        index `self.n_layers - 1` is the top layer (`is_top_layer=True`).
    """
    top_down_layers = nn.ModuleList([])
    nonlin = get_activation(self.nonlin)
    # The top prior shape depends only on model attributes, not on the layer
    # index, so compute it once instead of once per layer.
    top_prior_param_shape = self.get_top_prior_param_shape()
    # NOTE: top-down layers are created starting from the bottom-most
    for i in range(self.n_layers):
        # Check if this is the top layer
        is_top = i == self.n_layers - 1

        if self._enable_topdown_normalize_factor:  # TODO: What is this?
            # Scales the latent by 1/sqrt(2 * (i + 1)), but only for models with
            # more than 4 latent levels; otherwise no scaling is applied.
            normalize_latent_factor = (
                1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
            )
        else:
            normalize_latent_factor = 1.0

        top_down_layers.append(
            TopDownLayer(
                z_dim=self.z_dims[i],
                n_res_blocks=self.decoder_blocks_per_layer,
                n_filters=self.decoder_n_filters,
                is_top_layer=is_top,
                conv_strides=self.decoder_conv_strides,
                upsampling_steps=self.downsample[i],
                nonlin=nonlin,
                merge_type=self.merge_type,
                batchnorm=self.topdown_batchnorm,
                dropout=self.decoder_dropout,
                stochastic_skip=self.stochastic_skip,
                learn_top_prior=self.learn_top_prior,
                top_prior_param_shape=top_prior_param_shape,
                res_block_type=self.res_block_type,
                res_block_kernel=self.decoder_res_block_kernel,
                gated=self.gated,
                analytical_kl=self.analytical_kl,
                vanilla_latent_hw=self.get_latent_spatial_size(i),
                retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
                input_image_shape=self.image_size,
                normalize_latent_factor=normalize_latent_factor,
                conv2d_bias=self.topdown_conv2d_bias,
                stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
            )
        )
    return top_down_layers
|
|
449
|
+
|
|
450
|
+
def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
    """Build the last top-down stage of the Decoder.

    Unlike the other top-down layers, this stage performs its (optional)
    upsampling with an `Interpolate(scale=2)` module rather than a transposed
    convolution (presumably bilinear interpolation -- confirm against
    `Interpolate`). It then applies a stack of residual blocks.

    Parameters
    ----------
    upsample : bool
        If True, an `Interpolate(scale=2)` module is placed in front of the
        residual blocks so the input is upsampled by a factor of 2.

    Returns
    -------
    nn.Sequential
        The optional interpolation module followed by
        `self.decoder_blocks_per_layer` TopDownDeterministicResBlock modules.
    """
    head = [Interpolate(scale=2)] if upsample else []

    # Every residual block receives its own freshly created activation module.
    res_blocks = [
        TopDownDeterministicResBlock(
            c_in=self.decoder_n_filters,
            c_out=self.decoder_n_filters,
            nonlin=get_activation(self.nonlin),
            conv_strides=self.decoder_conv_strides,
            batchnorm=self.topdown_batchnorm,
            dropout=self.decoder_dropout,
            res_block_type=self.res_block_type,
            res_block_kernel=self.decoder_res_block_kernel,
            gated=self.gated,
            conv2d_bias=self.topdown_conv2d_bias,
        )
        for _ in range(self.decoder_blocks_per_layer)
    ]

    return nn.Sequential(*head, *res_blocks)
|
|
484
|
+
|
|
485
|
+
def _init_multires(self, config=None) -> None:
    """Build the input branches that encode the low-res lateral inputs (LC).

    In the multiresolution / Lateral Contextualization (LC) approach, extra
    low-resolution lateral inputs are fed to the Encoder at different
    hierarchical levels. The role of each input branch is similar to that of the
    first bottom-up layer in the primary flow of the Encoder: it compresses its
    lateral input to a degree compatible with the primary flow.

    The branches are stored in `self.lowres_first_bottom_ups` (an
    `nn.ModuleList`, or `None` when there are no lateral inputs). Nothing is
    returned.

    NOTE 1: Each input branch is `self.encoder_conv_op` (kernel size 5) +
    non-linearity + BottomUpDeterministicResBlock. The
    `BottomUpDeterministicResBlock` shares its settings with the blocks in the
    primary flow of the Encoder (e.g., filters, dropout, batchnorm), and it
    does not perform downsampling.

    NOTE 2: `_multiscale_count` is the total number of inputs to the bottom-up
    pass. With the input patch and n_LC additional lateral inputs there are
    (n_LC + 1) inputs, and therefore n_LC input branches. A `None` value is
    normalized to 1 (no lateral inputs).

    Parameters
    ----------
    config : optional
        Unused; kept for backward compatibility.

    Raises
    ------
    AssertionError
        If `_multiscale_count` exceeds `n_layers + 1`.
    """
    # NOTE(review): if `encoder_conv_op` is torch.nn.Conv2d/Conv3d, PyTorch
    # rejects padding="same" with stride > 1, i.e. whenever
    # `no_initial_downscaling` is False -- confirm this path is exercised.
    stride = 1 if self.no_initial_downscaling else 2
    nonlin = get_activation(self.nonlin)
    if self._multiscale_count is None:
        self._multiscale_count = 1

    msg = (
        f"Multiscale count ({self._multiscale_count}) should not exceed the number "
        f"of bottom up layers ({self.n_layers}) by more than 1.\n"
    )
    # Any count <= 1 also satisfies this bound, so a single comparison suffices.
    assert self._multiscale_count <= 1 + self.n_layers, msg

    # NOTE: restricting the multiscale approach to monochrome images
    # (color_ch == 1) is currently disabled.

    lowres_first_bottom_ups = []
    for _ in range(1, self._multiscale_count):
        first_bottom_up = nn.Sequential(
            self.encoder_conv_op(
                in_channels=self.color_ch,
                out_channels=self.encoder_n_filters,
                kernel_size=5,
                padding="same",
                stride=stride,
            ),
            nonlin,
            BottomUpDeterministicResBlock(
                c_in=self.encoder_n_filters,
                c_out=self.encoder_n_filters,
                conv_strides=self.encoder_conv_strides,
                nonlin=nonlin,
                downsample=False,
                batchnorm=self.bottomup_batchnorm,
                dropout=self.encoder_dropout,
                res_block_type=self.res_block_type,
            ),
        )
        lowres_first_bottom_ups.append(first_bottom_up)

    self.lowres_first_bottom_ups = (
        nn.ModuleList(lowres_first_bottom_ups) if lowres_first_bottom_ups else None
    )
|
|
553
|
+
|
|
554
|
+
### SET OF FORWARD-LIKE METHODS
|
|
555
|
+
def bottomup_pass(self, inp: torch.Tensor) -> list[torch.Tensor]:
    """Run the Encoder on `inp` using this model's own bottom-up modules.

    Thin convenience wrapper around `_bottomup_pass()`. It supplies
    `self.first_bottom_up`, `self.lowres_first_bottom_ups` and
    `self.bottom_up_layers` as the modules to use.

    Parameters
    ----------
    inp : torch.Tensor
        Input forwarded unchanged to `_bottomup_pass()`.

    Returns
    -------
    list[torch.Tensor]
        Whatever `_bottomup_pass()` returns.
    """
    # TODO Remove wrapper
    encoder_modules = (
        self.first_bottom_up,
        self.lowres_first_bottom_ups,
        self.bottom_up_layers,
    )
    return self._bottomup_pass(inp, *encoder_modules)
|
|
564
|
+
|
|
565
|
+
def _bottomup_pass(
    self,
    inp: torch.Tensor,
    first_bottom_up: nn.Sequential,
    lowres_first_bottom_ups: nn.ModuleList,
    bottom_up_layers: nn.ModuleList,
) -> list[torch.Tensor]:
    """Forward pass through the LVAE Encoder, i.e. the Bottom-Up pass.

    Parameters
    ----------
    inp : torch.Tensor
        Input of the bottom-up pass. When `self._multiscale_count > 1` (LC
        approach), channel 0 is the input patch and channels 1, 2, ... are the
        lateral low-res inputs (presumably shape (B, 1+n_LC, H, W)). Otherwise
        the whole tensor is passed to `first_bottom_up`.
    first_bottom_up : nn.Sequential
        The first bottom-up layer of the Encoder, applied to the primary input.
    lowres_first_bottom_ups : nn.ModuleList
        Lateral Contextualization input branches. Entry `i` encodes input
        channel `i + 1`, and its output is handed to bottom-up layer `i`. Only
        accessed when `self._multiscale_count > 1`.
    bottom_up_layers : nn.ModuleList
        The stack of bottom-up layers. Each is called as `layer(x, lowres_x=...)`
        and returns `(x, bu_value)`.

    Returns
    -------
    list[torch.Tensor]
        The `bu_value` of each of the `self.n_layers` layers, ordered from the
        bottom (index 0) to the top. These are the deterministic nodes needed by
        the top-down pass.
    """
    use_lateral = self._multiscale_count > 1
    primary_input = inp[:, :1] if use_lateral else inp
    x = first_bottom_up(primary_input)

    bu_values = []
    for level in range(self.n_layers):
        # A lateral input exists for this level only if the input tensor
        # actually carries channel `level + 1`.
        has_lateral = use_lateral and level + 1 < inp.shape[1]
        lateral = (
            lowres_first_bottom_ups[level](inp[:, level + 1 : level + 2])
            if has_lateral
            else None
        )
        x, bu_value = bottom_up_layers[level](x, lowres_x=lateral)
        bu_values.append(bu_value)

    return bu_values
|
|
607
|
+
|
|
608
|
+
def topdown_pass(
    self,
    bu_values: Union[list[Union[torch.Tensor, None]], None] = None,
    n_img_prior: Union[int, None] = None,
    constant_layers: Union[Iterable[int], None] = None,
    forced_latent: Union[list[torch.Tensor], None] = None,
    top_down_layers: Union[nn.ModuleList, None] = None,
    final_top_down_layer: Union[nn.Sequential, None] = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Forward pass through the LVAE Decoder, i.e. the Top-Down pass.

    Layers are visited from the top (index `self.n_layers - 1`) down to the
    bottom (index 0). Each layer receives the output of the layer above, and the
    output of the bottom layer goes through `final_top_down_layer`.

    Parameters
    ----------
    bu_values : list[torch.Tensor | None], optional
        Output of the bottom-up pass, one entry per layer of the ladder (index
        0 = bottom). If given, the pass runs in inference mode. If `None`, the
        latents are sampled from the prior.
    n_img_prior : int, optional
        Number of images to generate from the prior. It must be given if and
        only if `bu_values` is `None`, because the bottom-up pass is then not
        used at all.
    constant_layers : Iterable[int], optional
        Indexes of the layers in which a single instance's z is copied over the
        entire batch. Only the prior is used there, so this is incompatible
        with inference mode. Any iterable (including a generator) is accepted.
        Set to `None` to disable this behaviour.
    forced_latent : list[torch.Tensor], optional
        Per-layer tensors used as fixed latent variables, so sampling does not
        take place for those layers. Defaults to no forced latents.
    top_down_layers : nn.ModuleList, optional
        Top-down layers to use. Defaults to `self.top_down_layers`.
    final_top_down_layer : nn.Sequential, optional
        The last top-down layer of the pass. Defaults to `self.final_top_down`.

    Returns
    -------
    tuple[torch.Tensor, dict]
        The output of `final_top_down_layer`, and a dict of per-layer lists
        (index 0 = bottom) under the keys "z", "kl", "kl_restricted",
        "kl_spatial", "kl_channelwise", "q_mu", "q_lv" and "debug_qvar_max".

    Raises
    ------
    RuntimeError
        If `n_img_prior` is given together with `bu_values`, or if neither is
        given. Also raised if `constant_layers` is non-empty in inference mode.
    """
    if top_down_layers is None:
        top_down_layers = self.top_down_layers
    if final_top_down_layer is None:
        final_top_down_layer = self.final_top_down

    # Default: no layer is sampled from the distribution's mode.
    # Materialize the iterable once: generators have no len() and would be
    # exhausted by the repeated `in` checks below.
    constant_layers = tuple(constant_layers) if constant_layers is not None else ()
    prior_experiment = len(constant_layers) > 0

    # If the bottom-up inference values are not given, don't do
    # inference, sample from prior instead
    inference_mode = bu_values is not None

    # Check consistency of arguments
    if inference_mode != (n_img_prior is None):
        msg = (
            "Number of images for top-down generation has to be given "
            "if and only if we're not doing inference"
        )
        raise RuntimeError(msg)
    if inference_mode and prior_experiment:
        msg = (
            "Prior experiments (e.g. sampling from mode) are not"
            " compatible with inference mode"
        )
        raise RuntimeError(msg)

    # Sampled latent variables at each layer
    z = [None] * self.n_layers
    # KL divergence of each layer
    kl = [None] * self.n_layers
    # Kl divergence restricted, only for the LC enabled setup denoiSplit.
    kl_restricted = [None] * self.n_layers
    # mean from which z is sampled.
    q_mu = [None] * self.n_layers
    # log(var) from which z is sampled.
    q_lv = [None] * self.n_layers
    # Spatial map of KL divergence for each layer
    kl_spatial = [None] * self.n_layers
    debug_qvar_max = [None] * self.n_layers
    kl_channelwise = [None] * self.n_layers
    if forced_latent is None:
        forced_latent = [None] * self.n_layers

    # Top-down inference/generation loop
    out = None
    for i in reversed(range(self.n_layers)):
        # Deterministic node from bottom-up inference (inference mode only)
        bu_value = bu_values[i] if inference_mode else None

        # Whether the current layer should be sampled from the mode
        constant_out = i in constant_layers

        # Input for skip connection
        skip_input = out

        # Full top-down layer, including sampling and deterministic part
        out, aux = top_down_layers[i](
            input_=out,
            skip_connection_input=skip_input,
            inference_mode=inference_mode,
            bu_value=bu_value,
            n_img_prior=n_img_prior,
            force_constant_output=constant_out,
            forced_latent=forced_latent[i],
            mode_pred=self.mode_pred,
            var_clip_max=self._var_clip_max,
        )
        # Save useful variables
        z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
        kl[i] = aux["kl_samplewise"]  # (batch, )
        kl_restricted[i] = aux["kl_samplewise_restricted"]
        kl_spatial[i] = aux["kl_spatial"]  # (batch, h, w)
        q_mu[i] = aux["q_mu"]
        q_lv[i] = aux["q_lv"]

        kl_channelwise[i] = aux["kl_channelwise"]
        debug_qvar_max[i] = aux["qvar_max"]

    # Final top-down layer
    out = final_top_down_layer(out)

    # Store useful variables in a dict to return them
    data = {
        "z": z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
        "kl": kl,  # list of tensors with shape (batch, )
        "kl_restricted": kl_restricted,  # list of tensors with shape (batch, )
        "kl_spatial": kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
        "kl_channelwise": kl_channelwise,  # list of tensors with shape (batch, ch[i])
        "q_mu": q_mu,
        "q_lv": q_lv,
        "debug_qvar_max": debug_qvar_max,
    }
    return out, data
|
|
748
|
+
|
|
749
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Forward pass through the LVAE model.

    The pass runs in these steps:

    1. Run the bottom-up (Encoder) pass.
    2. Replace the lowest `self.skip_bottomk_buvalues` bottom-up values with
       `None`.
    3. Optionally squish 3D bottom-up values along dim 2.
    4. Run the top-down (Decoder) pass.
    5. Crop the result back to the input's spatial size if it came out larger.
    6. Apply `self.output_layer`.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor, presumably of shape (B, C, H, W) (or with an extra
        depth axis for 3D models -- not checked here).

    Returns
    -------
    tuple[torch.Tensor, dict[str, torch.Tensor]]
        The output of `self.output_layer`, and the auxiliary data returned by
        `topdown_pass()`.
    """
    img_size = x.size()[2:]

    # Bottom-up inference: return list of length n_layers (bottom to top)
    bu_values = self.bottomup_pass(x)
    for i in range(0, self.skip_bottomk_buvalues):
        bu_values[i] = None

    if self._squish3d:
        # Average out dim 2 after the per-level squisher. Skipped (None)
        # values are passed through instead of being fed to the squisher.
        bu_values = [
            None
            if bu_value is None
            else torch.mean(self._3D_squisher[k](bu_value), dim=2)
            for k, bu_value in enumerate(bu_values)
        ]

    # Top-down inference/generation
    out, td_data = self.topdown_pass(bu_values)

    # Restore original image size if padding made the output larger in either
    # of the two trailing spatial dimensions (not only the last one).
    trailing_pairs = zip(reversed(out.shape[-2:]), reversed(img_size[-2:]))
    if any(out_dim > in_dim for out_dim, in_dim in trailing_pairs):
        out = crop_img_tensor(out, img_size)

    out = self.output_layer(out)

    return out, td_data
|
|
781
|
+
|
|
782
|
+
### SET OF GETTERS
|
|
783
|
+
def get_padded_size(self, size):
    """Return the padded [H, W] size for an image of the given size.

    Each of the last two entries of `size` is rounded up to the nearest multiple
    of `self.overall_downscale_factor`, the overall downscale factor from the
    input to the top layer. Note that the result is a multiple of that factor,
    not necessarily a power of 2. When
    `self.multiscale_decoder_retain_spatial_dims` is True, no padding is
    computed and the last two entries are returned unchanged.

    Parameters
    ----------
    size : Sequence[int]
        Input size. Only the last two entries (Y, X) are used, so e.g.
        (N, C, H, W), (H, W) or (N, C, Z, H, W) are all accepted.

    Returns
    -------
    list[int]
        The padded [H, W].
    """
    # We're only interested in the Y, X dimensions (height, width).
    size = size[-2:]

    if self.multiscale_decoder_retain_spatial_dims is True:
        # In this case, we can go much deeper and so this is not required
        # (in the way it is). More work would be needed if this was to be
        # correctly implemented.
        return list(size)

    # Overall downscale factor from input to top layer
    dwnsc = self.overall_downscale_factor

    # Smallest multiples of `dwnsc` that are >= the current sizes (ceil division).
    padded_size = [((s - 1) // dwnsc + 1) * dwnsc for s in size]
    # TODO Needed for pad/crop odd sizes. Move to dataset?
    return padded_size
|
|
807
|
+
|
|
808
|
+
def get_latent_spatial_size(self, level_idx: int):
    """Return the (square) spatial size of the latent at a given level.

    Parameters
    ----------
    level_idx : int
        Level index. 0 is the bottommost layer, the highest resolution one.

    Returns
    -------
    int
        Side length of the latent: each padded image dimension (from
        `self.get_padded_size(self.image_size)`) floor-divided by
        2 ** (level_idx + 1).

    Raises
    ------
    AssertionError
        If the resulting latent would not be square.
    """
    factor = 2 ** (level_idx + 1)
    padded = self.get_padded_size(self.image_size)
    height, width = padded[0] // factor, padded[1] // factor
    assert height == width
    return height
|
|
817
|
+
|
|
818
|
+
def get_top_prior_param_shape(self, n_imgs: int = 1):
    """Return the shape of the top-layer prior parameters.

    Parameters
    ----------
    n_imgs : int, default=1
        Size of the leading dimension of the returned shape.

    Returns
    -------
    tuple[int, ...]
        `(n_imgs, 2 * z_dims[-1], h, w)`, or
        `(n_imgs, 2 * z_dims[-1], self._model_3D_depth, h, w)` when
        `self._model_3D_depth > 1` and `self._decoder_mode_3D is True`.
        The channel dimension stacks mean and log-variance. `h` and `w` are the
        last two entries of `self.image_size` floor-divided by the total
        downscaling performed in the Encoder.
    """
    if self.multiscale_decoder_retain_spatial_dims is not False:
        # LC allow the encoder latents to keep the same (H, W) size at
        # different levels.
        downscale = 2 ** (self.n_layers + 1 - self._multiscale_count)
    else:
        downscale = self.overall_downscale_factor

    height = self.image_size[-2] // downscale
    width = self.image_size[-1] // downscale
    mu_logvar = self.z_dims[-1] * 2  # mu and logvar

    # TODO refactor! (and check whether model_3D_depth is needed)
    if self._model_3D_depth > 1 and self._decoder_mode_3D is True:
        return (n_imgs, mu_logvar, self._model_3D_depth, height, width)
    return (n_imgs, mu_logvar, height, width)
|
|
837
|
+
|
|
838
|
+
def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
|
|
839
|
+
"""Should be called if we want to predict for a different input/output size."""
|
|
840
|
+
self.mode_pred = True
|
|
841
|
+
if tile_size is None:
|
|
842
|
+
tile_size = self.image_size
|
|
843
|
+
self.image_size = tile_size
|
|
844
|
+
for i in range(self.n_layers):
|
|
845
|
+
self.bottom_up_layers[i].output_expected_shape = (
|
|
846
|
+
ts // 2 ** (i + 1) for ts in tile_size
|
|
847
|
+
)
|
|
848
|
+
self.top_down_layers[i].latent_shape = tile_size
|