careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,1371 @@
|
|
|
1
|
+
"""Script containing the common basic blocks (nn.Module) reused by the LadderVAE."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from typing import Callable, Literal, Optional, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from .stochastic import NormalStochasticBlock
|
|
12
|
+
from .utils import (
|
|
13
|
+
crop_img_tensor,
|
|
14
|
+
pad_img_tensor,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Type aliases for the dimension-dependent layer classes that the blocks in this
# module look up at runtime with ``getattr(nn, f"<Layer>{n}d")`` (n = 2 or 3).
NormType = Union[nn.BatchNorm2d, nn.BatchNorm3d]
DropoutType = Union[nn.Dropout2d, nn.Dropout3d]
ConvType = Union[nn.Conv2d, nn.Conv3d]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ResidualBlock(nn.Module):
    """
    Residual block with 2 convolutional layers.

    Some architectural notes:
        - The number of input, intermediate, and output channels is the same,
        - Padding is always 'same',
        - The 2 convolutional layers have the same groups,
        - No stride allowed,
        - Kernel sizes must be odd.

    The output is given by: `out = gate(f(x)) + x`.
    The presence of the gating mechanism is optional, and f(x) has different
    structures depending on the `block_type` argument.
    Specifically, `block_type` is a string specifying the block's structure, with:
        a = activation
        b = batch norm
        c = conv layer
        d = dropout.
    For example, "bacdbacd" defines a block with 2x[batchnorm, activation, conv, dropout].

    Supported layouts are "cabdcabd", "bacdbac" (dropout only between the two
    convolutions) and "bacdbacd". In every layout, batch norm layers are only
    added when `batchnorm=True`, and dropout layers only when `dropout` is not
    `None`.
    """

    default_kernel_size = (3, 3)

    def __init__(
        self,
        channels: int,
        nonlin: Callable,
        conv_strides: tuple[int] = (2, 2),
        kernel: Union[int, Iterable[int], None] = None,
        groups: int = 1,
        batchnorm: bool = True,
        block_type: Optional[str] = None,
        dropout: Optional[float] = None,
        gated: Optional[bool] = None,
        conv2d_bias: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        channels: int
            The number of input and output channels (they are the same).
        nonlin: Callable
            The non-linearity module used in the block (e.g., `nn.ReLU()`).
            The same instance is reused at every activation position.
        conv_strides: tuple[int], optional
            Only its length is used here, to select 2D or 3D layers
            (`Conv{n}d`, `BatchNorm{n}d`, `Dropout{n}d`). Default is (2, 2).
        kernel: Union[int, Iterable[int]], optional
            Kernel sizes of the block's convolutions. A pair gives the (scalar)
            kernel size of the first and of the second convolution respectively;
            a single integer is used for both. All values must be odd.
            Default is `None`, meaning `default_kernel_size`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        block_type: str, optional
            A string specifying the block structure, check class docstring for more info.
            Default is `None`, which is not a valid layout and raises `ValueError`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to append a `GateLayer` at the end of the block. Default is `None`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.

        Raises
        ------
        ValueError
            If `kernel` has the wrong length or `block_type` is not recognized.
        """
        super().__init__()

        # Set kernel size: one odd, scalar kernel size per convolution
        if kernel is None:
            kernel = self.default_kernel_size
        elif isinstance(kernel, int):
            kernel = (kernel, kernel)
        elif len(kernel) != 2:
            raise ValueError("kernel has to be None, int, or an iterable of length 2")
        assert all(k % 2 == 1 for k in kernel), "kernel sizes have to be odd"
        kernel = list(kernel)

        # Define modules: the number of stride entries selects 2D vs 3D layers.
        # TODO: same comment as in lvae.py, would be more readable to have `conv_dims`
        n_dims = len(conv_strides)
        conv_layer: ConvType = getattr(nn, f"Conv{n_dims}d")
        norm_layer: NormType = getattr(nn, f"BatchNorm{n_dims}d")
        dropout_layer: DropoutType = getattr(nn, f"Dropout{n_dims}d")

        def make_conv(idx: int) -> nn.Module:
            # 'same' padding keeps the spatial shape, required by the residual sum.
            return conv_layer(
                channels,
                channels,
                kernel[idx],
                padding="same",
                groups=groups,
                bias=conv2d_bias,
            )

        modules = []
        if block_type == "cabdcabd":
            for i in range(2):
                modules.append(make_conv(i))
                modules.append(nonlin)
                if batchnorm:
                    modules.append(norm_layer(channels))
                if dropout is not None:
                    modules.append(dropout_layer(dropout))
        elif block_type == "bacdbac":
            for i in range(2):
                if batchnorm:
                    modules.append(norm_layer(channels))
                modules.append(nonlin)
                modules.append(make_conv(i))
                # Dropout only between the two convolutions, not at the end.
                if dropout is not None and i == 0:
                    modules.append(dropout_layer(dropout))
        elif block_type == "bacdbacd":
            for i in range(2):
                if batchnorm:
                    modules.append(norm_layer(channels))
                modules.append(nonlin)
                modules.append(make_conv(i))
                # Guarded like the other layouts: `DropoutNd(None)` cannot be built.
                if dropout is not None:
                    modules.append(dropout_layer(dropout))
        else:
            raise ValueError(f"unrecognized block type '{block_type}'")

        self.gated = gated
        if gated:
            modules.append(
                GateLayer(
                    channels=channels,
                    conv_strides=conv_strides,
                    kernel_size=1,
                    nonlin=nonlin,
                )
            )

        self.block = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with `channels` channels along dim 1 and 2 or 3
            spatial dims matching `len(conv_strides)`.

        Returns
        -------
        torch.Tensor
            `block(x) + x`, with the same shape as `x`.
        """
        out = self.block(x)
        # The residual sum requires the block to preserve the shape exactly.
        assert (
            out.shape == x.shape
        ), f"output shape: {out.shape} != input shape: {x.shape}"
        return out + x
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ResidualGatedBlock(ResidualBlock):
    """Residual block that always ends with a gating layer.

    Thin wrapper around `ResidualBlock` that forces `gated=True`; every other
    positional and keyword argument is forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Supplying `gated` explicitly as well raises TypeError (duplicate keyword).
        super().__init__(*args, gated=True, **kwargs)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class GateLayer(nn.Module):
    """
    Layer class that implements a gating mechanism.

    A single convolution doubles the channel count. Its output is split along
    dim 1 into a "value" half, passed through `nonlin`, and a "gate" half,
    squashed with a sigmoid. The result is their element-wise product, so it
    has as many channels as the input.
    """

    def __init__(
        self,
        channels: int,
        conv_strides: tuple[int] = (2, 2),
        kernel_size: int = 3,
        nonlin: Callable = nn.LeakyReLU(),
    ):
        """
        Constructor.

        Parameters
        ----------
        channels : int
            Number of input (and output) channels.
        conv_strides : tuple[int], optional
            Only its length is used, to choose between 2D and 3D convolution.
            Default is (2, 2).
        kernel_size : int, optional
            Odd kernel size of the gating convolution; padding of
            `kernel_size // 2` keeps the spatial shape. Default is 3.
        nonlin : Callable, optional
            Non-linearity applied to the value half. Default is `nn.LeakyReLU()`.
        """
        super().__init__()
        assert kernel_size % 2 == 1
        n_dims = len(conv_strides)
        conv_cls = getattr(nn, f"Conv{n_dims}d")
        self.conv = conv_cls(
            channels, 2 * channels, kernel_size, padding=kernel_size // 2
        )
        self.nonlin = nonlin

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input with `channels` channels along dim 1.

        Returns
        -------
        torch.Tensor
            `nonlin(value) * sigmoid(gate)`, with the same number of channels
            as `x`.
        """
        value, gate = torch.chunk(self.conv(x), 2, dim=1)
        # NOTE: original TODO questions whether this non-linearity is needed.
        return self.nonlin(value) * torch.sigmoid(gate)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class ResBlockWithResampling(nn.Module):
    """
    Residual block with resampling.

    Residual block that takes care of resampling (i.e. downsampling or upsampling) steps,
    by the per-axis factors given in `conv_strides` (e.g. 2).
    It is structured as follows:
        1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or
            a 1x1 convolutional layer that maps the number of channels of the input to `inner_channels`
            (set to `None` when no resampling is done and `c_in == inner_channels`).
        2. `ResidualBlock`
        3. `post_conv`: a 1x1 convolutional layer that maps the number of channels to `c_out`
            (set to `None` when `inner_channels == c_out`).

    Some implementation notes:
    - Resampling is performed through a strided convolution layer at the beginning of the block.
    - The strided convolution block has fixed kernel size of 3 and 1 layer of padding with zeros.
    - The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers.
    - The number of internal channels is `max(c_out, min_inner_channels)`, i.e. by default the same
      as the number of output channels.
    """

    def __init__(
        self,
        mode: Literal["top-down", "bottom-up"],
        c_in: int,
        c_out: int,
        conv_strides: tuple[int],
        min_inner_channels: Union[int, None] = None,
        nonlin: Callable = nn.LeakyReLU(),
        resample: bool = False,
        res_block_kernel: Optional[Union[int, Iterable[int]]] = None,
        groups: int = 1,
        batchnorm: bool = True,
        res_block_type: Union[str, None] = None,
        dropout: Union[float, None] = None,
        gated: Union[bool, None] = None,
        conv2d_bias: bool = True,
        # lowres_input: bool = False,
    ):
        """
        Constructor.

        Parameters
        ----------
        mode: Literal["top-down", "bottom-up"]
            The type of resampling performed in the initial strided convolution of the block.
            If "bottom-up", downsampling by `conv_strides` is done (strided convolution).
            If "top-down", upsampling by `conv_strides` is done (strided transposed convolution).
        c_in: int
            The number of input channels.
        c_out: int
            The number of output channels.
        conv_strides: tuple[int]
            Per-axis strides of the resampling convolution. Its length (2 or 3)
            also selects 2D or 3D layers.
        min_inner_channels: int, optional
            Lower bound on the number of channels used in the inner layers of this module;
            the inner channel count is `max(c_out, min_inner_channels)`.
            Default is `None`, meaning that the number of inner channels is set to `c_out`.
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is `nn.LeakyReLU`.
        resample: bool, optional
            Whether to perform resampling in the first convolutional layer.
            If `False`, the first convolutional layer just maps the input to a tensor with
            `inner_channels` channels through 1x1 convolution. Default is `False`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size(s) used in the convolutions of the residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        groups: int, optional
            The number of groups to consider in the convolutions. Default is 1.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        gated: bool, optional
            Whether to use gated layer. Default is `None`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.
        """
        super().__init__()
        assert mode in ["top-down", "bottom-up"]

        conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")
        transp_conv_layer: ConvType = getattr(nn, f"ConvTranspose{len(conv_strides)}d")

        if min_inner_channels is None:
            min_inner_channels = 0
        # inner_channels is the number of channels used in the inner layers
        # of ResBlockWithResampling
        inner_channels = max(c_out, min_inner_channels)

        # Define first conv layer to change num channels and/or up/downsample
        if resample:
            if mode == "bottom-up":  # downsample
                self.pre_conv = conv_layer(
                    in_channels=c_in,
                    out_channels=inner_channels,
                    kernel_size=3,
                    padding=1,
                    stride=conv_strides,
                    groups=groups,
                    bias=conv2d_bias,
                )
            elif mode == "top-down":  # upsample
                # With kernel_size=3 and padding=1, a transposed convolution maps
                # size `n` to `(n - 1) * s + 1 + output_padding` along an axis with
                # stride `s`; `output_padding = s - 1` makes it exactly `n * s`.
                # This equals the former hard-coded values for strides (2, 2) -> 1
                # and (1, 2, 2) -> (0, 1, 1), and also covers other strides.
                output_padding = tuple(s - 1 for s in conv_strides)
                self.pre_conv = transp_conv_layer(
                    in_channels=c_in,
                    kernel_size=3,
                    out_channels=inner_channels,
                    padding=1,  # TODO maybe don't hardcode this?
                    stride=conv_strides,
                    groups=groups,
                    output_padding=output_padding,
                    bias=conv2d_bias,
                )
        elif c_in != inner_channels:
            self.pre_conv = conv_layer(
                c_in, inner_channels, 1, groups=groups, bias=conv2d_bias
            )
        else:
            self.pre_conv = None

        # Residual block
        self.res = ResidualBlock(
            channels=inner_channels,
            conv_strides=conv_strides,
            nonlin=nonlin,
            kernel=res_block_kernel,
            groups=groups,
            batchnorm=batchnorm,
            dropout=dropout,
            gated=gated,
            block_type=res_block_type,
            conv2d_bias=conv2d_bias,
        )

        # Define last conv layer to get correct num output channels
        if inner_channels != c_out:
            self.post_conv = conv_layer(
                inner_channels, c_out, 1, groups=groups, bias=conv2d_bias
            )
        else:
            self.post_conv = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with `c_in` channels along dim 1.

        Returns
        -------
        torch.Tensor
            Output tensor with `c_out` channels along dim 1; spatial dims are
            resampled by `conv_strides` when `resample=True`, otherwise unchanged.
        """
        if self.pre_conv is not None:
            x = self.pre_conv(x)

        x = self.res(x)

        if self.post_conv is not None:
            x = self.post_conv(x)
        return x
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
class TopDownDeterministicResBlock(ResBlockWithResampling):
    """Resnet block for top-down deterministic layers.

    Fixes `mode="top-down"` in `ResBlockWithResampling` and forwards the
    `upsample` flag as `resample`; when it is `True`, the block upsamples
    through a strided transposed convolution.
    """

    def __init__(self, *args, upsample: bool = False, **kwargs):
        # `upsample` always wins over any `resample` keyword the caller passed.
        forwarded = {**kwargs, "resample": upsample}
        super().__init__("top-down", *args, **forwarded)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class BottomUpDeterministicResBlock(ResBlockWithResampling):
    """Resnet block for bottom-up deterministic layers.

    Fixes `mode="bottom-up"` in `ResBlockWithResampling` and forwards the
    `downsample` flag as `resample`; when it is `True`, the block downsamples
    through a strided convolution.
    """

    def __init__(self, *args, downsample: bool = False, **kwargs):
        # `downsample` always wins over any `resample` keyword the caller passed.
        forwarded = {**kwargs, "resample": downsample}
        super().__init__("bottom-up", *args, **forwarded)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
class BottomUpLayer(nn.Module):
    """
    Bottom-up deterministic layer.

    It consists of one or a stack of `BottomUpDeterministicResBlock`'s.
    The outputs are the so-called `bu_values` that are later used in the Decoder to
    update the generative distributions.

    NOTE: When Lateral Contextualization is enabled (i.e., `enable_multiscale=True`),
    the low-res lateral input is first fed through a BottomUpDeterministicBlock (BUDB)
    (without downsampling), and then merged into the latent tensor produced by the
    primary flow of the `BottomUpLayer` through the `MergeLowRes` layer. The BUDB that
    encodes the low-res input can either be shared with the primary flow (in which
    case it is the "same_size" BUDB, or stack of BUDBs -> see `self.net`), or be a
    deep copy of the primary flow's BUDB.
    This behaviour is controlled by the `lowres_separate_branch` parameter.
    """

    def __init__(
        self,
        n_res_blocks: int,
        n_filters: int,
        conv_strides: tuple[int, ...] = (2, 2),
        downsampling_steps: int = 0,
        nonlin: Optional[Callable] = None,
        batchnorm: bool = True,
        dropout: Optional[float] = None,
        res_block_type: Optional[str] = None,
        res_block_kernel: Optional[int] = None,
        gated: Optional[bool] = None,
        enable_multiscale: bool = False,
        multiscale_lowres_size_factor: Optional[int] = None,
        lowres_separate_branch: bool = False,
        multiscale_retain_spatial_dims: bool = False,
        decoder_retain_spatial_dims: bool = False,
        output_expected_shape: Optional[Iterable[int]] = None,
    ):
        """
        Constructor.

        Parameters
        ----------
        n_res_blocks: int
            Number of `BottomUpDeterministicResBlock` modules stacked in this layer.
        n_filters: int
            Number of channels used throughout the layers of this block.
        conv_strides: tuple, optional
            The strides used in the convolutions; its length selects the
            convolution dimensionality. Default is `(2, 2)`.
        downsampling_steps: int, optional
            Number of downsampling steps done in this layer (typically 1). The first
            `downsampling_steps` blocks downsample, the remaining ones keep the size.
            Default is 0.
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is `None`.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            It can be either a single integer or a pair of integers defining the
            squared kernel. Default is `None`.
        gated: bool, optional
            Whether to use gated layer. Default is `None`.
        enable_multiscale: bool, optional
            Whether to enable multiscale (Lateral Contextualization) or not.
            Default is `False`.
        multiscale_lowres_size_factor: int, optional
            A factor that expresses the relative size of the primary flow tensor with
            respect to the lower-resolution lateral input tensor. Default is `None`.
        lowres_separate_branch: bool, optional
            Whether the residual block(s) encoding the low-res input should be shared
            (`False`) or not (`True`) with the primary flow "same-size" residual
            block(s). Default is `False`.
        multiscale_retain_spatial_dims: bool, optional
            Whether to pad the latent tensor resulting from the bottom-up layer's
            primary flow to match the size of the low-res input. Default is `False`.
        decoder_retain_spatial_dims: bool, optional
            Whether in the corresponding top-down layer the shape of tensor is
            retained between input and output. Default is `False`.
        output_expected_shape: Iterable[int], optional
            The expected shape of the layer output (only allowed if
            `enable_multiscale` is enabled). Default is `None`.
        """
        super().__init__()

        # Attributes for Lateral Contextualization (LC)
        self.enable_multiscale = enable_multiscale
        self.lowres_separate_branch = lowres_separate_branch
        self.multiscale_retain_spatial_dims = multiscale_retain_spatial_dims
        self.multiscale_lowres_size_factor = multiscale_lowres_size_factor
        self.decoder_retain_spatial_dims = decoder_retain_spatial_dims
        self.output_expected_shape = output_expected_shape
        # An explicit output shape is only meaningful when LC is enabled.
        assert self.output_expected_shape is None or self.enable_multiscale

        bu_blocks_downsized = []
        bu_blocks_samesize = []
        for _ in range(n_res_blocks):
            # The first `downsampling_steps` blocks downsample; the rest keep size.
            do_resample = downsampling_steps > 0
            if do_resample:
                downsampling_steps -= 1
            block = BottomUpDeterministicResBlock(
                conv_strides=conv_strides,
                c_in=n_filters,
                c_out=n_filters,
                nonlin=nonlin,
                downsample=do_resample,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
                res_block_kernel=res_block_kernel,
                gated=gated,
            )
            if do_resample:
                bu_blocks_downsized.append(block)
            else:
                bu_blocks_samesize.append(block)

        self.net_downsized = nn.Sequential(*bu_blocks_downsized)
        self.net = nn.Sequential(*bu_blocks_samesize)

        # The same-size net may be reused for the low-res (larger field of view) input.
        self.lowres_net = self.lowres_merge = None
        if self.enable_multiscale:
            self._init_multiscale(
                n_filters=n_filters,
                conv_strides=conv_strides,
                nonlin=nonlin,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
            )

    def _init_multiscale(
        self,
        nonlin: Optional[Callable] = None,
        n_filters: Optional[int] = None,
        conv_strides: tuple[int, ...] = (2, 2),
        batchnorm: Optional[bool] = None,
        dropout: Optional[float] = None,
        res_block_type: Optional[str] = None,
    ) -> None:
        """
        Bottom-up layer's method that initializes the LC modules.

        Defines the modules responsible for merging compressed lateral inputs to the
        outputs of the primary flow at different hierarchical levels in the
        multiresolution approach (LC). Specifically, the method initializes
        `lowres_net`, which is a stack of `BottomUpDeterministicBlock`'s (w/out
        downsampling) that takes care of additionally processing the low-res input,
        and `lowres_merge`, which is the module responsible for merging the
        compressed lateral input to the main flow.

        NOTE: The merge modality is set to "residual", meaning that the merge layer
        performs concatenation on dim=1, followed by 1x1 convolution and a Residual
        Gated block.

        Parameters
        ----------
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is `None`.
        n_filters: int, optional
            Number of channels used throughout the layers of this block.
            Default is `None`.
        conv_strides: tuple, optional
            The strides used in the convolutions. Default is `(2, 2)`.
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `None`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        """
        # Either share the primary-flow same-size net or use an independent copy.
        self.lowres_net = self.net
        if self.lowres_separate_branch:
            self.lowres_net = deepcopy(self.net)

        self.lowres_merge = MergeLowRes(
            channels=n_filters,
            conv_strides=conv_strides,
            merge_type="residual",
            nonlin=nonlin,
            batchnorm=batchnorm,
            dropout=dropout,
            res_block_type=res_block_type,
            multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
            multiscale_lowres_size_factor=self.multiscale_lowres_size_factor,
        )

    def forward(
        self, x: torch.Tensor, lowres_x: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input of the `BottomUpLayer`, i.e., the input image or the output of
            the previous layer.
        lowres_x: torch.Tensor, optional
            The low-res input used for Lateral Contextualization (LC). Must be `None`
            when LC is disabled. Default is `None`.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The first tensor is used as input for the next BU layer, while the second
            tensor is the bu_value passed to the top-down layer.
        """
        # The input is fed through the residual downsampling block(s) ...
        primary_flow = self.net_downsized(x)
        # ... and then through the additional same-size residual block(s).
        primary_flow = self.net(primary_flow)

        # If LC is not used, simply return the output of the primary flow.
        # Truthiness is used here to match the check done in `__init__`.
        if not self.enable_multiscale:
            assert lowres_x is None
            return primary_flow, primary_flow

        if lowres_x is not None:
            # First encode the low-res lateral input, then merge it with the
            # primary flow through the MergeLowRes layer.
            lowres_flow = self.lowres_net(lowres_x)
            merged = self.lowres_merge(primary_flow, lowres_flow)
        else:
            merged = primary_flow

        # NOTE: Explanation of possible cases for the conditional below:
        # - both retain flags set -> `merged` has the same spatial dims as the input
        #   (`x`) since spatial dims are retained by padding `primary_flow` in
        #   `MergeLowRes`. This is OK for the corresponding TopDown layer, as it also
        #   retains spatial dims.
        # - both unset -> `merged`'s spatial dims equal `self.net_downsized(x)`,
        #   since no padding is done in `MergeLowRes`; instead, the lowres input is
        #   cropped. This is OK for the corresponding TopDown layer, as it also
        #   halves the spatial dims.
        # - multiscale unset but decoder set -> expected not to happen (see the
        #   initialization of `multiscale_decoder_retain_spatial_dims` in lvae.py).
        if not self.multiscale_retain_spatial_dims or self.decoder_retain_spatial_dims:
            return merged, merged

        # Here merging LC preserved the spatial dimensions, but the corresponding
        # top-down layer does not retain them, so the bu_value must be cropped.
        if self.output_expected_shape is not None:
            expected_shape = self.output_expected_shape
        else:
            fac = self.multiscale_lowres_size_factor
            expected_shape = (merged.shape[-2] // fac, merged.shape[-1] // fac)
            assert merged.shape[-2:] != expected_shape

        # Crop the resulting tensor so that it matches with the Decoder.
        value_to_use_in_topdown = crop_img_tensor(merged, expected_shape)
        return merged, value_to_use_in_topdown
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
class MergeLayer(nn.Module):
    """
    Layer class that merges two or more input tensors.

    Merges two or more (B, C, [Z], Y, X) input tensors by concatenating
    them along dim=1 and passes the result through:
    a) a convolutional 1x1 layer (`merge_type == "linear"`), or
    b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or
    c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`).
    """

    def __init__(
        self,
        merge_type: Literal["linear", "residual", "residual_ungated"],
        channels: Union[int, Iterable[int]],
        conv_strides: tuple[int, ...] = (2, 2),
        nonlin: Callable = nn.LeakyReLU(),
        batchnorm: bool = True,
        dropout: Optional[float] = None,
        res_block_type: Optional[str] = None,
        res_block_kernel: Optional[int] = None,
        conv2d_bias: Optional[bool] = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        merge_type: Literal["linear", "residual", "residual_ungated"]
            The type of merge done in the layer. It can be chosen between "linear",
            "residual", and "residual_ungated". Check the class docstring for more
            information about the behaviour of different merge modalities.
            `None` is tolerated for backward compatibility, but then no merge
            layer is built and `forward` cannot be used.
        channels: Union[int, Iterable[int]]
            The number of channels used in the convolutional blocks of this layer.
            If it is an `int` (or a length-1 iterable):
            - 1st 1x1 Conv: in_channels=2*channels, out_channels=channels
            - (Optional) ResBlock: in_channels=channels, out_channels=channels
            If it is an Iterable (must have `len(channels)==3`):
            - 1st 1x1 Conv: in_channels=sum(channels[:-1]),
                out_channels=channels[-1]
            - (Optional) ResBlock: in_channels=channels[-1],
                out_channels=channels[-1]
        conv_strides: tuple, optional
            The strides used in the convolutions; its length selects the
            convolution dimensionality (`Conv{len}d`). Default is `(2, 2)`.
        nonlin: Callable, optional
            The non-linearity function used in the block. Default is an
            `nn.LeakyReLU` instance (created once and shared by all layers that
            rely on the default; LeakyReLU holds no parameters).
        batchnorm: bool, optional
            Whether to use batchnorm layers. Default is `True`.
        dropout: float, optional
            The dropout probability in dropout layers. If `None` dropout is not used.
            Default is `None`.
        res_block_type: str, optional
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
            Default is `None`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            It can be either a single integer or a pair of integers defining the
            squared kernel. Default is `None`.
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.

        Raises
        ------
        ValueError
            If `merge_type` is not `None` and not one of the supported values.
        """
        super().__init__()
        try:
            iter(channels)
        except TypeError:  # it is not iterable
            channels = [channels] * 3
        else:  # it is iterable
            if len(channels) == 1:
                channels = [channels[0]] * 3

        self.conv_layer: ConvType = getattr(nn, f"Conv{len(conv_strides)}d")

        in_channels = sum(channels[:-1])
        out_channels = channels[-1]

        if merge_type == "linear":
            self.layer = self.conv_layer(in_channels, out_channels, 1, bias=conv2d_bias)
        elif merge_type in ("residual", "residual_ungated"):
            residual_cls = (
                ResidualGatedBlock if merge_type == "residual" else ResidualBlock
            )
            # The 1x1 conv is built before the residual block, keeping parameter
            # creation order and state-dict keys ("layer.0", "layer.1") stable.
            self.layer = nn.Sequential(
                self.conv_layer(
                    in_channels, out_channels, 1, padding=0, bias=conv2d_bias
                ),
                residual_cls(
                    conv_strides=conv_strides,
                    channels=out_channels,
                    nonlin=nonlin,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    block_type=res_block_type,
                    kernel=res_block_kernel,
                    conv2d_bias=conv2d_bias,
                ),
            )
        elif merge_type is not None:
            # Fail early on typos instead of leaving `self.layer` undefined.
            raise ValueError(
                f"Unknown merge_type {merge_type!r}; expected 'linear', "
                "'residual' or 'residual_ungated'."
            )

    def forward(self, *args) -> torch.Tensor:
        """Concatenate the inputs along dim=1 and apply the merge layer.

        Parameters
        ----------
        *args: torch.Tensor
            Tensors to merge. They must agree on every dim except dim=1, and their
            channel counts must add up to the input channels of the first 1x1 conv.

        Returns
        -------
        torch.Tensor
            The merged tensor.
        """
        x = torch.cat(args, dim=1)
        return self.layer(x)
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
class MergeLowRes(MergeLayer):
    """
    Child class of `MergeLayer`.

    Specifically designed to merge the low-resolution patches
    that are used in the Lateral Contextualization approach.
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor.

        Accepts every `MergeLayer` argument plus two required keyword arguments:

        multiscale_retain_spatial_dims: bool
            If truthy, the latent tensor is padded to the low-res tensor's spatial
            shape before merging; otherwise the low-res tensor is center-cropped to
            the latent tensor's spatial shape.
        multiscale_lowres_size_factor: int
            Relative size factor between the low-res input and the latent tensor
            (stored on the instance).
        """
        self.retain_spatial_dims = kwargs.pop("multiscale_retain_spatial_dims")
        self.multiscale_lowres_size_factor = kwargs.pop("multiscale_lowres_size_factor")
        super().__init__(*args, **kwargs)

    def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        latent: torch.Tensor
            The output latent tensor from previous layer in the LVAE hierarchy.
        lowres: torch.Tensor
            The low-res patch image to be merged to increase the context.

        Returns
        -------
        torch.Tensor
            The merged tensor.
        """
        if self.retain_spatial_dims:
            # Pad latent tensor to match lowres tensor's shape:
            # Output.shape == Lowres.shape (== Input.shape),
            # where Input is the input to the BU layer.
            latent = pad_img_tensor(latent, lowres.shape[2:])
        else:
            # Crop lowres tensor to match latent tensor's shape.
            lowres = self._center_crop_to(lowres, latent.shape[2:])

        return super().forward(latent, lowres)

    @staticmethod
    def _center_crop_to(tensor: torch.Tensor, spatial_shape) -> torch.Tensor:
        """Center-crop the spatial dims (dims 2+) of `tensor` to `spatial_shape`.

        Works for any number of spatial dims (2D and 3D). Uses explicit
        `start:start + size` slices so a zero offset keeps the full extent.
        """
        index = [slice(None), slice(None)]
        for full, target in zip(tensor.shape[2:], spatial_shape):
            if target > full:
                raise ValueError(
                    f"Cannot crop low-res tensor of spatial shape "
                    f"{tuple(tensor.shape[2:])} to larger shape {tuple(spatial_shape)}."
                )
            start = (full - target) // 2
            index.append(slice(start, start + target))
        return tensor[tuple(index)]
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
class SkipConnectionMerger(MergeLayer):
    """Specialized `MergeLayer` module, handles skip connections in the model."""

    def __init__(
        self,
        nonlin: Callable,
        channels: Union[int, Iterable[int]],
        batchnorm: bool,
        dropout: Optional[float],
        res_block_type: Optional[str],
        conv_strides: tuple[int, ...] = (2, 2),
        merge_type: Literal["linear", "residual", "residual_ungated"] = "residual",
        conv2d_bias: bool = True,
        res_block_kernel: Optional[int] = None,
    ):
        """
        Constructor.

        Parameters
        ----------
        nonlin: Callable
            The non-linearity function used in the block. Required (no default).
        channels: Union[int, Iterable[int]]
            The number of channels used in the convolutional blocks of this layer.
            If it is an `int`:
            - 1st 1x1 Conv: in_channels=2*channels, out_channels=channels
            - (Optional) ResBlock: in_channels=channels, out_channels=channels
            If it is an Iterable (must have `len(channels)==3`):
            - 1st 1x1 Conv: in_channels=sum(channels[:-1]), out_channels=channels[-1]
            - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1]
        batchnorm: bool
            Whether to use batchnorm layers.
        dropout: float or None
            The dropout probability in dropout layers. If `None` dropout is not used.
        res_block_type: str or None
            A string specifying the structure of residual block.
            Check `ResidualBlock` docstring for more information.
        conv_strides: tuple, optional
            The strides used in the convolutions. Default is `(2, 2)`.
        merge_type: Literal["linear", "residual", "residual_ungated"], optional
            The type of merge done in the layer. It can be chosen between "linear",
            "residual", and "residual_ungated". Check the `MergeLayer` docstring for
            more information about the behaviour of different merge modalities.
            Default is "residual".
        conv2d_bias: bool, optional
            Whether to use bias term in convolutions. Default is `True`.
        res_block_kernel: Union[int, Iterable[int]], optional
            The kernel size used in the convolutions of the residual block.
            It can be either a single integer or a pair of integers defining the
            squared kernel. Default is `None`.
        """
        super().__init__(
            conv_strides=conv_strides,
            channels=channels,
            nonlin=nonlin,
            merge_type=merge_type,
            batchnorm=batchnorm,
            dropout=dropout,
            res_block_type=res_block_type,
            res_block_kernel=res_block_kernel,
            conv2d_bias=conv2d_bias,
        )
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
class TopDownLayer(nn.Module):
|
|
907
|
+
"""Top-down inference layer.
|
|
908
|
+
|
|
909
|
+
It includes:
|
|
910
|
+
- Stochastic sampling,
|
|
911
|
+
- Computation of KL divergence,
|
|
912
|
+
- A small deterministic ResNet that performs upsampling.
|
|
913
|
+
|
|
914
|
+
NOTE 1:
|
|
915
|
+
The algorithm for generative inference approximately works as follows:
|
|
916
|
+
- p_params = output of top-down layer above
|
|
917
|
+
- bu = inferred bottom-up value at this layer
|
|
918
|
+
- q_params = merge(bu, p_params)
|
|
919
|
+
- z = stochastic_layer(q_params)
|
|
920
|
+
- (optional) get and merge skip connection from prev top-down layer
|
|
921
|
+
- top-down deterministic ResNet
|
|
922
|
+
|
|
923
|
+
NOTE 2:
|
|
924
|
+
The Top-Down layer can work in two modes: inference and prediction/generative.
|
|
925
|
+
Depending on the particular mode, it follows distinct behaviours:
|
|
926
|
+
- In inference mode, parameters of q(z_i|z_i+1) are obtained from the inference path,
|
|
927
|
+
by merging outcomes of bottom-up and top-down passes. The exception is the top layer,
|
|
928
|
+
in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer.
|
|
929
|
+
- On the contrary in predicition/generative mode, parameters of q(z_i|z_i+1) can be obtained
|
|
930
|
+
once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is
|
|
931
|
+
possible to directly sample from the prior p(z_i|z_i+1) (UNCONDITIONAL GENERATION).
|
|
932
|
+
|
|
933
|
+
NOTE 3:
|
|
934
|
+
When doing unconditional generation, bu_value is not available. Hence the
|
|
935
|
+
merge layer is not used, and z is sampled directly from p_params.
|
|
936
|
+
|
|
937
|
+
NOTE 4:
|
|
938
|
+
If this is the top layer, at inference time, the uppermost bottom-up value
|
|
939
|
+
is used directly as q_params, and p_params are defined in this layer
|
|
940
|
+
(while they are usually taken from the previous layer), and can be learned.
|
|
941
|
+
"""
|
|
942
|
+
|
|
943
|
+
def __init__(
    self,
    z_dim: int,
    n_res_blocks: int,
    n_filters: int,
    conv_strides: tuple[int, ...],
    is_top_layer: bool = False,
    upsampling_steps: Union[int, None] = None,
    nonlin: Union[Callable, None] = None,
    merge_type: Union[
        Literal["linear", "residual", "residual_ungated"], None
    ] = None,
    batchnorm: bool = True,
    dropout: Union[float, None] = None,
    stochastic_skip: bool = False,
    res_block_type: Union[str, None] = None,
    res_block_kernel: Union[int, None] = None,
    groups: int = 1,
    gated: Union[bool, None] = None,
    learn_top_prior: bool = False,
    top_prior_param_shape: Union[Iterable[int], None] = None,
    analytical_kl: bool = False,
    retain_spatial_dims: bool = False,
    vanilla_latent_hw: Union[Iterable[int], None] = None,
    input_image_shape: Union[tuple[int, ...], None] = None,
    normalize_latent_factor: float = 1.0,
    conv2d_bias: bool = True,
    stochastic_use_naive_exponential: bool = False,
):
    """
    Constructor.

    Parameters
    ----------
    z_dim: int
        The size of the latent space.
    n_res_blocks: int
        The number of TopDownDeterministicResBlock blocks.
    n_filters: int
        The number of channels used throughout the layers of this block.
    conv_strides: tuple
        The strides used in the convolutions; its length selects 2D or 3D
        convolutions.
    is_top_layer: bool, optional
        Whether the current layer is at the top of the Decoder hierarchy.
        Default is `False`.
    upsampling_steps: int, optional
        The number of upsampling steps done in this layer (typically 1). The first
        `upsampling_steps` blocks upsample; `None` is treated as 0.
        Default is `None`.
    nonlin: Callable, optional
        The non-linearity function used in the block (e.g., `nn.ReLU`).
        Default is `None`.
    merge_type: Literal["linear", "residual", "residual_ungated"], optional
        The type of merge done in the layer. It can be chosen between "linear",
        "residual", and "residual_ungated". Check the `MergeLayer` class docstring
        for more information about the behaviour of different merging modalities.
        Default is `None`.
    batchnorm: bool, optional
        Whether to use batchnorm layers. Default is `True`.
    dropout: float, optional
        The dropout probability in dropout layers. If `None` dropout is not used.
        Default is `None`.
    stochastic_skip: bool, optional
        Whether to use skip connections between previous top-down layer's output
        and this layer's stochastic output. The stochastic skip connection gives the
        previous layer's output a direct path to this hierarchical level, hence
        facilitating the gradient flow during backpropagation. Default is `False`.
    res_block_type: str, optional
        A string specifying the structure of residual block.
        Check `ResidualBlock` documentation for more information.
        Default is `None`.
    res_block_kernel: Union[int, Iterable[int]], optional
        The kernel size used in the convolutions of the residual block.
        It can be either a single integer or a pair of integers defining the
        squared kernel. Default is `None`.
    groups: int, optional
        The number of groups to consider in the convolutions. Default is 1.
    gated: bool, optional
        Whether to use gated layer in `ResidualBlock`. Default is `None`.
    learn_top_prior: bool, optional
        Whether to set the top prior as learnable.
        If this is set to `False`, in the top-most layer the prior will be N(0,1).
        Otherwise, we will still have a normal distribution whose parameters will
        be learnt. Default is `False`.
    top_prior_param_shape: Iterable[int], optional
        The size of the tensor which expresses the mean and the variance
        of the prior for the top most layer. Required when `is_top_layer` is
        `True`. Default is `None`.
    analytical_kl: bool, optional
        If True, KL divergence is calculated according to the analytical formula.
        Otherwise, an MC approximation using sampled latents is calculated.
        Default is `False`.
    retain_spatial_dims: bool, optional
        If `True`, the size of Encoder's latent space is kept to
        `input_image_shape` within the topdown layer.
        This implies that the output spatial size equals the input spatial size.
        To achieve this, we centercrop the intermediate representation.
        Default is `False`.
    vanilla_latent_hw: Iterable[int], optional
        The shape of the latent tensor used for prediction (i.e., it influences
        the computation of restricted KL). Default is `None`.
    input_image_shape: tuple[int, ...], optional
        The shape of the input image tensor. For 2D convolutions
        (`len(conv_strides) == 2`) the first entry is dropped.
        When `retain_spatial_dims` is set to `True`, this is used to ensure that
        this layer's output has the same shape as the input. Default is `None`.
    normalize_latent_factor: float, optional
        A factor used to normalize the latent tensors `q_params`.
        Specifically, normalization is done by dividing the latent tensor by this
        factor. Default is 1.0.
    conv2d_bias: bool, optional
        Whether to use bias term in the convolutional blocks of this layer.
        Default is `True`.
    stochastic_use_naive_exponential: bool, optional
        If `False`, in the NormalStochasticBlock2d exponentials are computed
        according to the alternative definition provided by `StableExponential`
        class. This should improve numerical stability in the training process.
        Default is `False`.

    Raises
    ------
    ValueError
        If `is_top_layer` is `True` but `top_prior_param_shape` is `None`.
    """
    super().__init__()

    self.is_top_layer = is_top_layer
    self.z_dim = z_dim
    self.stochastic_skip = stochastic_skip
    self.learn_top_prior = learn_top_prior
    self.analytical_kl = analytical_kl
    self.retain_spatial_dims = retain_spatial_dims
    # For 2D models the leading entry of `input_image_shape` is dropped so that
    # only the dims matching the convolution rank are kept. `None` (the default)
    # is passed through instead of crashing on slicing.
    if input_image_shape is None:
        self.input_image_shape = None
    elif len(conv_strides) == 3:
        self.input_image_shape = input_image_shape
    else:
        self.input_image_shape = input_image_shape[1:]
    self.latent_shape = self.input_image_shape if self.retain_spatial_dims else None
    self.normalize_latent_factor = normalize_latent_factor
    self._vanilla_latent_hw = vanilla_latent_hw  # TODO: check this, it is not used

    # Define top layer prior parameters, possibly learnable
    if is_top_layer:
        if top_prior_param_shape is None:
            raise ValueError(
                "`top_prior_param_shape` must be provided when `is_top_layer` is True."
            )
        self.top_prior_params = nn.Parameter(
            torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior
        )

    # Upsampling steps left to do in this layer (`None` means no upsampling).
    ups_left = upsampling_steps or 0

    # Define deterministic top-down block, which is a sequence of deterministic
    # residual blocks; the first `upsampling_steps` of them upsample.
    block_list = []
    for _ in range(n_res_blocks):
        do_resample = ups_left > 0
        if do_resample:
            ups_left -= 1
        block_list.append(
            TopDownDeterministicResBlock(
                c_in=n_filters,
                c_out=n_filters,
                conv_strides=conv_strides,
                nonlin=nonlin,
                upsample=do_resample,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
                res_block_kernel=res_block_kernel,
                gated=gated,
                conv2d_bias=conv2d_bias,
                groups=groups,
            )
        )
    self.deterministic_block = nn.Sequential(*block_list)

    # Define stochastic block with convolutions
    self.stochastic = NormalStochasticBlock(
        c_in=n_filters,
        c_vars=z_dim,
        c_out=n_filters,
        conv_dims=len(conv_strides),
        transform_p_params=(not is_top_layer),
        vanilla_latent_hw=vanilla_latent_hw,
        use_naive_exponential=stochastic_use_naive_exponential,
    )

    if not is_top_layer:
        # Merge layer: it combines bottom-up inference and top-down
        # generative outcomes to give posterior parameters
        self.merge = MergeLayer(
            channels=n_filters,
            conv_strides=conv_strides,
            merge_type=merge_type,
            nonlin=nonlin,
            batchnorm=batchnorm,
            dropout=dropout,
            res_block_type=res_block_type,
            res_block_kernel=res_block_kernel,
            conv2d_bias=conv2d_bias,
        )

        # Skip connection that goes around the stochastic top-down layer
        if stochastic_skip:
            self.skip_connection_merger = SkipConnectionMerger(
                channels=n_filters,
                conv_strides=conv_strides,
                nonlin=nonlin,
                batchnorm=batchnorm,
                dropout=dropout,
                res_block_type=res_block_type,
                merge_type=merge_type,
                conv2d_bias=conv2d_bias,
                res_block_kernel=res_block_kernel,
            )
|
|
1144
|
+
|
|
1145
|
+
def sample_from_q(
    self,
    input_: torch.Tensor,
    bu_value: torch.Tensor,
    var_clip_max: Optional[float] = None,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Sample a latent tensor from the inference distribution q(z_i|z_{i+1}).

    In the top layer the bottom-up value is used directly as the parameters of
    q. In any other layer the bottom-up value is merged with the prior
    parameters coming from the layer above (``self.merge``).

    Parameters
    ----------
    input_: torch.Tensor
        The input tensor to the layer, which is the output of the top-down layer
        above. Ignored in the top layer.
    bu_value: torch.Tensor
        The tensor defining the parameters mu_q and sigma_q computed during the
        bottom-up deterministic pass at the correspondent hierarchical layer.
    var_clip_max: float, optional
        The maximum value reachable by the log-variance of the latent distribution.
        Values exceeding this threshold are clipped. Default is `None`.
    mask: torch.Tensor, optional
        A tensor used to index (mask) the sampled latent tensor. Default is `None`.

    Returns
    -------
    torch.Tensor
        The sampled latent tensor, or ``sample[mask]`` when a mask is given.
    """
    if self.is_top_layer:
        # In the top layer, bu_value is not merged with p_params.
        q_params = bu_value
    else:
        # NOTE: Here the assumption is that the vampprior is only applied on the
        # top layer, hence no images are requested from the prior here.
        n_img_prior = None
        p_params = self.get_p_params(input_, n_img_prior)
        q_params = self.merge(bu_value, p_params)

    # NOTE(review): unlike `forward`, `normalize_latent_factor` is not applied to
    # q_params here — confirm whether that is intentional.
    sample = self.stochastic.sample_from_q(q_params, var_clip_max)

    # `if mask:` would evaluate the truthiness of a tensor, which raises for any
    # mask with more than one element; compare against None explicitly.
    if mask is not None:
        return sample[mask]

    return sample
|
|
1184
|
+
|
|
1185
|
+
def get_p_params(
    self,
    input_: torch.Tensor,
    n_img_prior: Optional[int],
) -> torch.Tensor:
    """Return the parameters of the prior distribution p(z_i|z_{i+1}).

    The parameters depend on the hierarchical level of the layer:
    - if it is the topmost level, parameters are the ones of the prior
      (``self.top_prior_params``);
    - else, the input from the layer above is the parameters itself.

    Parameters
    ----------
    input_: torch.Tensor
        The input tensor to the layer, which is the output of the top-down layer above.
        Ignored in the top layer.
    n_img_prior: int or None
        The number of images to be generated from the unconditional prior distribution
        p(z_L). Only used in the top layer; if `None`, the prior is returned as is.

    Returns
    -------
    torch.Tensor
        The prior parameters for this layer.
    """
    # Non-top layers: the input from the layer above is p_params itself.
    if not self.is_top_layer:
        return input_

    # Top layer: p_params are the ones of the prior p(z_L).
    p_params = self.top_prior_params

    # Sample a specific number of images by expanding the prior along the batch
    # dimension. Keeping all trailing dimensions works for both 2D (N, C, H, W)
    # and 3D (N, C, D, H, W) latents; expand requires the leading dim to be 1
    # (or already equal to n_img_prior).
    if n_img_prior is not None:
        p_params = p_params.expand(n_img_prior, *p_params.shape[1:])

    return p_params
|
|
1218
|
+
|
|
1219
|
+
def forward(
    self,
    input_: Union[torch.Tensor, None] = None,
    skip_connection_input: Union[torch.Tensor, None] = None,
    inference_mode: bool = False,
    bu_value: Union[torch.Tensor, None] = None,
    n_img_prior: Union[int, None] = None,
    forced_latent: Union[torch.Tensor, None] = None,
    force_constant_output: bool = False,
    mode_pred: bool = False,
    use_uncond_mode: bool = False,
    var_clip_max: Union[float, None] = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Forward pass.

    Parameters
    ----------
    input_: torch.Tensor, optional
        The input tensor to the layer, which is the output of the top-down layer.
        Must be `None` in the top layer. Default is `None`.
    skip_connection_input: torch.Tensor, optional
        The tensor brought by the skip connection between the current and the
        previous top-down layer. Must be `None` in the top layer.
        Default is `None`.
    inference_mode: bool, optional
        Whether the layer is in inference mode, i.e. whether the latent is sampled
        from q(z_i | z_{i+1}, x) (`True`) or from p(z_i | z_{i+1}) (`False`).
        Default is `False`.
    bu_value: torch.Tensor, optional
        The tensor defining the parameters mu_q and sigma_q computed during the
        bottom-up deterministic pass at the correspondent hierarchical layer.
        Default is `None`.
    n_img_prior: int, optional
        The number of images to be generated from the unconditional prior
        distribution p(z_L).
        Default is `None`.
    forced_latent: torch.Tensor, optional
        A pre-defined latent tensor. If it is not `None`, then it is used as the
        actual latent tensor and, hence, sampling does not happen.
        Default is `None`.
    force_constant_output: bool, optional
        Whether to copy the first sample (and rel. distrib parameters) over the
        whole batch.
        This is used when doing experiment from the prior - q is not used.
        Default is `False`.
    mode_pred: bool, optional
        Whether the model is in prediction mode. Default is `False`.
    use_uncond_mode: bool, optional
        Whether to use the unconditional distribution p(z) to sample latents in
        prediction mode (only affects non-top layers).
    var_clip_max: float, optional
        The maximum value reachable by the log-variance of the latent distribution.
        Values exceeding this threshold are clipped.

    Returns
    -------
    tuple[torch.Tensor, dict[str, torch.Tensor]]
        The output of the deterministic block, and a dictionary with the keys
        "z", "kl_samplewise", "kl_samplewise_restricted", "kl_spatial",
        "kl_channelwise", "logprob_q", "qvar_max" (copied from the stochastic
        block's output, `None` if missing) plus "q_mu" and "q_lv" (`None` when
        the stochastic block returned no q parameters).

    Raises
    ------
    ValueError
        If this is the top layer and `input_` or `skip_connection_input` is given.
    """
    # Check consistency of arguments
    inputs_none = input_ is None and skip_connection_input is None
    if self.is_top_layer and not inputs_none:
        raise ValueError("In top layer, inputs should be None")

    p_params = self.get_p_params(input_, n_img_prior)

    # Get the parameters of the latent distribution to sample from.
    if inference_mode:
        if self.is_top_layer:
            q_params = bu_value
            if mode_pred is False:
                assert p_params.shape[2:] == bu_value.shape[2:], (
                    "Spatial dimensions of p_params and bu_value should match. "
                    f"Instead, we got p_params={p_params.shape[2:]} and "
                    f"bu_value={bu_value.shape[2:]}."
                )
        else:
            if use_uncond_mode:
                q_params = p_params
            else:
                assert p_params.shape[2:] == bu_value.shape[2:], (
                    "Spatial dimensions of p_params and bu_value should match. "
                    f"Instead, we got p_params={p_params.shape[2:]} and "
                    f"bu_value={bu_value.shape[2:]}."
                )
                q_params = self.merge(bu_value, p_params)
    else:  # generative mode, q is not used, we sample from p(z_i | z_{i+1})
        q_params = None

    # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
    # depending on the mode (hence, in practice, by checking whether q_params is None).

    # Normalization of latent space parameters for stability.
    # See Very deep VAEs generalize autoregressive models.
    # In generative mode q_params is None, so there is nothing to normalize
    # (dividing None would raise a TypeError).
    if self.normalize_latent_factor and q_params is not None:
        q_params = q_params / self.normalize_latent_factor

    # Sample (and process) a latent tensor in the stochastic layer
    x, data_stoch = self.stochastic(
        p_params=p_params,
        q_params=q_params,
        forced_latent=forced_latent,
        force_constant_output=force_constant_output,
        analytical_kl=self.analytical_kl,
        mode_pred=mode_pred,
        use_uncond_mode=use_uncond_mode,
        var_clip_max=var_clip_max,
    )

    # Merge skip connection from previous layer
    if self.stochastic_skip and not self.is_top_layer:
        x = self.skip_connection_merger(x, skip_connection_input)

    if self.retain_spatial_dims:
        # NOTE: we assume that one topdown layer will have exactly one upscaling layer.

        # NOTE: in case, in the Bottom-Up layer, LC retains spatial dimensions,
        # we have the following (see `MergeLowRes`):
        # - the "primary-flow" tensor is padded to match the low-res patch size
        #   (e.g., from 32x32 to 64x64)
        # - padded tensor is then merged with the low-res patch (concatenation
        #   along dim=1 + convolution)
        # Therefore, we need to do the symmetric operation here, that is to
        # crop `x` for the same amount we padded it in the correspondent BU layer.

        # NOTE: cropping is done to retain the shape of the input in the output.
        # Therefore we need it only in the case `x` is the same shape of the input,
        # because that's the only case in which we need to retain the shape.
        # Here, it must be strictly greater than half the input shape, which is
        # the case if and only if `x.shape == self.latent_shape`.
        rescale = (
            np.array((1, 2, 2)) if len(self.latent_shape) == 3 else np.array((2, 2))
        )  # TODO better way?
        new_latent_shape = tuple(np.array(self.latent_shape) // rescale)
        if x.shape[-1] > new_latent_shape[-1]:
            x = crop_img_tensor(x, new_latent_shape)
    # TODO: `retain_spatial_dims` is the same for all the TD layers.
    # How to handle the case in which we do not have LC for all layers?
    # The answer is in `self.latent_shape`, which is equal to `input_image_shape`
    # (e.g., (64, 64)) if `retain_spatial_dims` is `True`, else it is `None`.

    # Last top-down block (sequence of residual blocks w\ upsampling)
    x = self.deterministic_block(x)

    # Save some metrics that will be used in the loss computation
    keys = [
        "z",
        "kl_samplewise",
        "kl_samplewise_restricted",
        "kl_spatial",
        "kl_channelwise",
        "logprob_q",
        "qvar_max",
    ]
    data = {k: data_stoch.get(k, None) for k in keys}
    data["q_mu"] = None
    data["q_lv"] = None
    if data_stoch["q_params"] is not None:
        q_mu, q_lv = data_stoch["q_params"]
        data["q_mu"] = q_mu
        data["q_lv"] = q_lv
    return x, data
|