careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,701 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lightning Module for LadderVAE.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
import ml_collections
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytorch_lightning as L
|
|
10
|
+
import torch
|
|
11
|
+
import torchvision.transforms.functional as F
|
|
12
|
+
|
|
13
|
+
from careamics.models.lvae.likelihoods import LikelihoodModule
|
|
14
|
+
from careamics.models.lvae.lvae import LadderVAE
|
|
15
|
+
from careamics.models.lvae.utils import (
|
|
16
|
+
LossType,
|
|
17
|
+
compute_batch_mean,
|
|
18
|
+
free_bits_kl,
|
|
19
|
+
torch_nanmean,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from .metrics import RangeInvariantPsnr, RunningPSNR
|
|
23
|
+
from .train_utils import MetricMonitor
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LadderVAELight(L.LightningModule):
|
|
27
|
+
|
|
28
|
+
def __init__(
    self,
    config: ml_collections.ConfigDict,
    data_mean: Dict[str, torch.Tensor],
    data_std: Dict[str, torch.Tensor],
    target_ch: int,
):
    """
    Create the wrapped LadderVAE model and the training/loss attributes.

    Parameters
    ----------
    config: ml_collections.ConfigDict
        Configuration object. Read here: `workdir`, `loss.kl_loss_formulation`,
        `training.lr`, `training.lr_scheduler_patience` and the optional
        `model.monitor` entry (falls back to "val_loss"). It is also forwarded
        to `LadderVAE`.
    data_mean: Dict[str, torch.Tensor]
        Stored on the module and forwarded to `LadderVAE`.
    data_std: Dict[str, torch.Tensor]
        Stored on the module and forwarded to `LadderVAE`.
    target_ch: int
        Stored on the module and forwarded to `LadderVAE`.

    NOTE:
    Some attributes live on the model object (`self.model.some_attr`), while the
    ones related to training and loss are (re)defined here as Lightning module
    attributes, copying the model's value when the model already defines it
    (e.g., `self.loss_type = self.model.loss_type`).

    NOTE: HC stands for Hard Coded attribute.
    """
    super().__init__()

    self.data_mean, self.data_std, self.target_ch = data_mean, data_std, target_ch

    # Initialize LVAE model
    self.model = LadderVAE(
        config=config, data_mean=data_mean, data_std=data_std, target_ch=target_ch
    )

    ##### Define attributes from config #####
    self.workdir = config.workdir
    self._input_is_sum = False
    self.kl_loss_formulation = config.loss.kl_loss_formulation
    valid_formulations = (None, "", "usplit", "denoisplit", "denoisplit_usplit")
    assert (
        self.kl_loss_formulation in valid_formulations
    ), f"\nInvalid kl_loss_formulation. {self.kl_loss_formulation}"

    ##### Define loss attributes #####
    # Parameters already defined in the model object
    self.loss_type = self.model.loss_type
    self._denoisplit_w = None
    self._usplit_w = None
    if self.loss_type == LossType.DenoiSplitMuSplit:
        # The two weights are complementary; the muSplit share is fixed to 0 here.
        self._usplit_w = 0
        self._denoisplit_w = 1 - self._usplit_w
        assert self._denoisplit_w + self._usplit_w == 1
    self._restricted_kl = self.model._restricted_kl

    # General loss parameters
    self.channel_1_w = 1
    self.channel_2_w = 1

    # About Reconstruction Loss
    self.reconstruction_mode = False
    self.skip_nboundary_pixels_from_loss = None
    self.reconstruction_weight = 1.0
    self._exclusion_loss_weight = 0
    self.ch1_recons_w = 1
    self.ch2_recons_w = 1
    self.enable_mixed_rec = False
    self.mixed_rec_w_step = 0

    # About KL Loss
    self.kl_weight = 1.0  # HC
    self.usplit_kl_weight = None  # HC
    self.free_bits = 1.0  # HC
    self.kl_annealing = False  # HC
    # Annealing parameters are only populated when annealing is switched on.
    self.kl_annealtime = 10 if self.kl_annealing else None  # HC
    self.kl_start = -1 if self.kl_annealing else None  # HC

    ##### Define training attributes #####
    self.lr = config.training.lr
    self.lr_scheduler_patience = config.training.lr_scheduler_patience
    self.lr_scheduler_monitor = config.model.get("monitor", "val_loss")
    monitor = MetricMonitor(self.lr_scheduler_monitor)
    self.lr_scheduler_mode = monitor.mode()

    # One running-PSNR tracker per output channel
    n_channels = self.model.target_ch
    self.channels_psnr = [RunningPSNR() for _ in range(n_channels)]
|
|
114
|
+
|
|
115
|
+
def forward(self, x: Any) -> Any:
    """Pass `x` to the wrapped model and return whatever the model returns."""
    model_output = self.model(x)
    return model_output
|
|
117
|
+
|
|
118
|
+
def training_step(
    self, batch: torch.Tensor, batch_idx: int, enable_logging: bool = True
) -> Dict[str, torch.Tensor]:
    """
    Run the forward pass on one batch and compute the training loss.

    Parameters
    ----------
    batch: torch.Tensor
        Batch whose first two entries are the input and the target.
    batch_idx: int
        Index of the batch within the current epoch.
    enable_logging: bool = True
        If `True`, the losses, the learning rate and the debug quantities
        returned by the model are logged through `self.log`.

    Returns
    -------
    Dict[str, torch.Tensor] or None
        Dictionary with the keys "loss", "reconstruction_loss" and "kl_loss".
        `None` is returned when the net loss is NaN or could not be computed as
        a tensor (Lightning skips the optimisation step for such a batch).

    Raises
    ------
    ValueError
        If `self.kl_loss_formulation` is not handled for the current loss type.
    """
    if self.current_epoch == 0 and batch_idx == 0:
        # Placeholder value logged once on the very first step.
        self.log("val_psnr", 1.0, on_epoch=True)

    # Pre-processing of inputs
    x, target = batch[:2]
    self.set_params_to_same_device_as(x)
    x_normalized = self.normalize_input(x)
    if self.reconstruction_mode:  # just for experimental purpose
        target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1)
        target = None
        mask = None
    else:
        target_normalized = self.normalize_target(target)
        # Keep only the batch items whose target is not entirely zero.
        mask = ~((target == 0).reshape(len(target), -1).all(dim=1))

    # Forward pass
    out, td_data = self.forward(x_normalized)

    if (
        self.model.encoder_no_padding_mode
        and out.shape[-2:] != target_normalized.shape[-2:]
    ):
        target_normalized = F.center_crop(target_normalized, out.shape[-2:])

    # Loss Computations
    recons_loss_dict, _ = self.get_reconstruction_loss(
        reconstruction=out,
        target=target_normalized,
        input=x_normalized,
        splitting_mask=mask,
        return_predicted_img=True,
    )

    # This `if` is not used by default config
    if self.skip_nboundary_pixels_from_loss:
        pad = self.skip_nboundary_pixels_from_loss
        target_normalized = target_normalized[:, :, pad:-pad, pad:-pad]

    recons_loss = recons_loss_dict["loss"] * self.reconstruction_weight

    if torch.isnan(recons_loss).any():
        recons_loss = 0.0

    if self.model.non_stochastic_version:
        # Allocate on the device of the model output instead of hard-coding
        # CUDA, so that CPU (or any other accelerator) training works too.
        kl_loss = torch.zeros(1, dtype=torch.float32, device=out.device)
        net_loss = recons_loss
    else:
        if self.loss_type == LossType.DenoiSplitMuSplit:
            msg = f"For the loss type {LossType.name(self.loss_type)}, kl_loss_formulation must be denoisplit_usplit"
            assert self.kl_loss_formulation == "denoisplit_usplit", msg
            assert self._denoisplit_w is not None and self._usplit_w is not None

            kl_key_denoisplit = "kl_restricted" if self._restricted_kl else "kl"
            # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
            # The different naming comes from `top_down_pass()` method in the LadderVAE class.
            denoisplit_kl = self.get_kl_divergence_loss(
                topdown_layer_data_dict=td_data, kl_key=kl_key_denoisplit
            )
            usplit_kl = self.get_kl_divergence_loss_usplit(
                topdown_layer_data_dict=td_data
            )
            kl_loss = (
                self._denoisplit_w * denoisplit_kl + self._usplit_w * usplit_kl
            )
            kl_loss = self.kl_weight * kl_loss

            recons_loss = self.reconstruction_loss_musplit_denoisplit(
                out, target_normalized
            )

        elif self.kl_loss_formulation == "usplit":
            kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss_usplit(
                td_data
            )
        elif self.kl_loss_formulation in ["", "denoisplit"]:
            kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss(td_data)
        else:
            # Without this branch `kl_loss` would stay unbound and the method
            # would fail further down with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported kl_loss_formulation {self.kl_loss_formulation!r} "
                f"for loss type {self.loss_type!r}."
            )
        net_loss = recons_loss + kl_loss

    # Logging
    if enable_logging:
        # A dedicated loop name keeps the input batch `x` from being shadowed.
        for i, qvar_max in enumerate(td_data["debug_qvar_max"]):
            self.log(f"qvar_max:{i}", qvar_max.item(), on_epoch=True)

        self.log("reconstruction_loss", recons_loss_dict["loss"], on_epoch=True)
        self.log("kl_loss", kl_loss, on_epoch=True)
        self.log("training_loss", net_loss, on_epoch=True)
        self.log("lr", self.lr, on_epoch=True)
        if self.model._tethered_ch2_scalar is not None:
            self.log(
                "tethered_ch2_scalar",
                self.model._tethered_ch2_scalar,
                on_epoch=True,
            )
            self.log(
                "tethered_ch1_scalar",
                self.model._tethered_ch1_scalar,
                on_epoch=True,
            )

    # https://github.com/openai/vdvae/blob/main/train.py#L26
    # `net_loss` is a plain float only when the reconstruction loss was NaN in
    # the non-stochastic setting; such a batch is skipped like a NaN loss
    # instead of crashing inside `torch.isnan`.
    if not isinstance(net_loss, torch.Tensor) or torch.isnan(net_loss).any():
        return None

    return {
        "loss": net_loss,
        "reconstruction_loss": (
            recons_loss.detach()
            if isinstance(recons_loss, torch.Tensor)
            else recons_loss
        ),
        "kl_loss": kl_loss.detach(),
    }
|
|
240
|
+
|
|
241
|
+
def validation_step(self, batch: torch.Tensor, batch_idx: int):
    """
    Evaluate one validation batch and log the per-channel PSNR and the loss.

    Parameters
    ----------
    batch: torch.Tensor
        Batch whose first two entries are the input and the target.
    batch_idx: int
        Index of the batch (not used in the computation).

    Returns
    -------
    None
        The per-channel values are logged as "val_psnr_l<k>" and the
        reconstruction loss as "val_loss"; the latter is skipped when it is NaN.
    """
    # Pre-processing of inputs
    x, target = batch[:2]
    self.set_params_to_same_device_as(x)
    x_normalized = self.normalize_input(x)
    if not self.reconstruction_mode:
        target_normalized = self.normalize_target(target)
        # Keep only the batch items whose target is not entirely zero.
        all_zero_items = (target == 0).reshape(len(target), -1).all(dim=1)
        mask = ~all_zero_items
    else:  # only for experimental purpose
        target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1)
        target = None
        mask = None

    # Forward pass
    out, _ = self.forward(x_normalized)

    out_mean = out
    if self.model.predict_logvar is not None:
        # The output is split in two halves along dim 1; only the first is kept.
        out_mean, _ = out.chunk(2, dim=1)

    shape_mismatch = (
        self.model.encoder_no_padding_mode
        and out.shape[-2:] != target_normalized.shape[-2:]
    )
    if shape_mismatch:
        target_normalized = F.center_crop(target_normalized, out.shape[-2:])

    if self.loss_type == LossType.DenoiSplitMuSplit:
        recons_loss_dict = {
            "loss": self.reconstruction_loss_musplit_denoisplit(
                out, target_normalized
            )
        }
        recons_img = out_mean
    else:
        # Metrics computation
        recons_loss_dict, recons_img = self.get_reconstruction_loss(
            reconstruction=out_mean,
            target=target_normalized,
            input=x_normalized,
            splitting_mask=mask,
            return_predicted_img=True,
        )

    # This `if` is not used by default config
    if self.skip_nboundary_pixels_from_loss:
        pad = self.skip_nboundary_pixels_from_loss
        target_normalized = target_normalized[:, :, pad:-pad, pad:-pad]

    for ch_idx in range(target_normalized.shape[1]):
        target_ch = target_normalized[:, ch_idx]
        recons_ch = recons_img[:, ch_idx]
        # Accumulate into the running tracker used at the end of the epoch.
        self.channels_psnr[ch_idx].update(recons_ch, target_ch)
        psnr_values = RangeInvariantPsnr(target_ch.clone(), recons_ch.clone())
        batch_psnr = torch_nanmean(psnr_values).item()
        self.log(f"val_psnr_l{ch_idx + 1}", batch_psnr, on_epoch=True)

    recons_loss = recons_loss_dict["loss"]
    if not torch.isnan(recons_loss).any():
        self.log("val_loss", recons_loss, on_epoch=True)
|
|
304
|
+
# self.log('val_psnr', (val_psnr_l1 + val_psnr_l2) / 2, on_epoch=True)
|
|
305
|
+
|
|
306
|
+
# if batch_idx == 0 and self.power_of_2(self.current_epoch):
|
|
307
|
+
# all_samples = []
|
|
308
|
+
# for i in range(20):
|
|
309
|
+
# sample, _ = self(x_normalized[0:1, ...])
|
|
310
|
+
# sample = self.likelihood.get_mean_lv(sample)[0]
|
|
311
|
+
# all_samples.append(sample[None])
|
|
312
|
+
|
|
313
|
+
# all_samples = torch.cat(all_samples, dim=0)
|
|
314
|
+
# all_samples = all_samples * self.data_std + self.data_mean
|
|
315
|
+
# all_samples = all_samples.cpu()
|
|
316
|
+
# img_mmse = torch.mean(all_samples, dim=0)[0]
|
|
317
|
+
# self.log_images_for_tensorboard(all_samples[:, 0, 0, ...], target[0, 0, ...], img_mmse[0], 'label1')
|
|
318
|
+
# self.log_images_for_tensorboard(all_samples[:, 0, 1, ...], target[0, 1, ...], img_mmse[1], 'label2')
|
|
319
|
+
|
|
320
|
+
# return net_loss
|
|
321
|
+
|
|
322
|
+
def on_validation_epoch_end(self):
    """
    Log the epoch-level "val_psnr" and reset the per-channel PSNR trackers.

    The logged value is the mean of the per-channel running PSNR values, or
    0.0 as soon as one tracker has no value (its `get()` returns `None`).
    Every tracker is reset, including the ones following a tracker without a
    value, so that no statistics are carried over into the next epoch.

    If `self.mixed_rec_w_step` is set, `self.mixed_rec_w` is decreased by that
    step (clamped at 0.0) and logged as "mixed_rec_w".
    """
    psnr_arr = []
    for channel_psnr in self.channels_psnr:
        psnr = channel_psnr.get()
        # Reset unconditionally; stopping at the first missing value would
        # leave the remaining trackers accumulating across epochs.
        channel_psnr.reset()
        if psnr is None:
            psnr_arr = None
        elif psnr_arr is not None:
            psnr_arr.append(psnr.cpu().numpy())

    if psnr_arr is not None:
        self.log("val_psnr", np.mean(psnr_arr), on_epoch=True)
    else:
        self.log("val_psnr", 0.0, on_epoch=True)

    if self.mixed_rec_w_step:
        # NOTE(review): `mixed_rec_w` is expected to be set elsewhere before
        # `mixed_rec_w_step` is made non-zero -- confirm against the caller.
        self.mixed_rec_w = max(self.mixed_rec_w - self.mixed_rec_w_step, 0.0)
        self.log("mixed_rec_w", self.mixed_rec_w, on_epoch=True)
|
|
341
|
+
|
|
342
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
    """Prediction is not supported by this module; always raises `NotImplementedError`."""
    message = "predict_step is not implemented"
    raise NotImplementedError(message)
|
|
344
|
+
|
|
345
|
+
def configure_optimizers(self):
    """
    Create the optimizer and the learning-rate scheduler.

    Returns
    -------
    dict
        Dictionary with the keys "optimizer" (an `Adamax` optimizer over
        `self.parameters()` with learning rate `self.lr`), "lr_scheduler" (a
        `ReduceLROnPlateau` scheduler that halves the learning rate, with
        patience `self.lr_scheduler_patience` and mode `self.lr_scheduler_mode`)
        and "monitor" (`self.lr_scheduler_monitor`).
    """
    optimizer = torch.optim.Adamax(self.parameters(), lr=self.lr, weight_decay=0)
    # `verbose` is deliberately not passed: the argument is deprecated in the
    # PyTorch LR schedulers (since 2.2) and only triggers a warning there.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=self.lr_scheduler_mode,
        patience=self.lr_scheduler_patience,
        factor=0.5,
        min_lr=1e-12,
    )

    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "monitor": self.lr_scheduler_monitor,
    }
|
|
361
|
+
|
|
362
|
+
##### REQUIRED Methods for Loss Computation #####
|
|
363
|
+
def get_reconstruction_loss(
    self,
    reconstruction: torch.Tensor,
    target: torch.Tensor,
    input: torch.Tensor,
    splitting_mask: torch.Tensor = None,
    return_predicted_img: bool = False,
    likelihood_obj: LikelihoodModule = None,
) -> Dict[str, torch.Tensor]:
    """
    Reduce the per-item reconstruction losses to one value per loss key.

    The per-item losses come from `_get_reconstruction_loss_vector`. For the
    "loss" key and every "ch<k>_loss" key, the entries selected by
    `splitting_mask` are summed and divided by `len(reconstruction)`.
    If present, "mixed_loss" is replaced by its mean.

    Parameters
    ----------
    reconstruction: torch.Tensor
    target: torch.Tensor
        Its size along dim 1 fixes how many "ch<k>_loss" keys are reduced.
    input: torch.Tensor
    splitting_mask: torch.Tensor = None
        A boolean tensor that indicates which items to keep for reconstruction loss computation.
        If `None`, all the items are considered (i.e., the mask is all `True`).
    return_predicted_img: bool = False
        If `True`, the second element returned by `_get_reconstruction_loss_vector`
        is returned alongside the loss dictionary.
    likelihood_obj: LikelihoodModule = None
        Forwarded as-is to `_get_reconstruction_loss_vector`.

    Returns
    -------
    The loss dictionary, or the tuple `(loss dictionary, second output)` when
    `return_predicted_img` is `True`.
    """
    output = self._get_reconstruction_loss_vector(
        reconstruction=reconstruction,
        target=target,
        input=input,
        return_predicted_img=return_predicted_img,
        likelihood_obj=likelihood_obj,
    )
    if return_predicted_img:
        loss_dict = output[0]
    else:
        loss_dict = output

    if splitting_mask is None:
        splitting_mask = torch.ones_like(loss_dict["loss"]).bool()

    # The normaliser is the full batch length, not the number of kept items.
    n_items = len(reconstruction)
    reduced_keys = ["loss"]
    reduced_keys.extend(f"ch{i}_loss" for i in range(1, 1 + target.shape[1]))
    for key in reduced_keys:
        loss_dict[key] = loss_dict[key][splitting_mask].sum() / n_items

    if "mixed_loss" in loss_dict:
        loss_dict["mixed_loss"] = torch.mean(loss_dict["mixed_loss"])

    if not return_predicted_img:
        return loss_dict

    assert len(output) == 2
    return loss_dict, output[1]
|
|
412
|
+
|
|
413
|
+
def _get_reconstruction_loss_vector(
    self,
    reconstruction: torch.Tensor,
    target: torch.Tensor,
    input: torch.Tensor,
    return_predicted_img: bool = False,
    likelihood_obj: LikelihoodModule = None,
):
    """
    Compute the reconstruction loss terms as negative log-likelihoods.

    Parameters
    ----------
    reconstruction : torch.Tensor
        Model output handed to the likelihood module.
    target : torch.Tensor
        Target handed to the likelihood module. `target.shape[1]` defines
        how many `ch{i}_loss` entries are produced in the multi-channel case.
    input : torch.Tensor
        Network input. Only used when `self.enable_mixed_rec` is set.
    return_predicted_img : bool
        If set to `True`, besides the loss, the predicted mean image is also
        returned. Default is `False`.
    likelihood_obj : LikelihoodModule, optional
        Likelihood module to evaluate. Defaults to `self.model.likelihood`.

    Returns
    -------
    dict or tuple
        Dictionary with the key `loss`, the per-channel keys `ch{i}_loss`
        and, when `self.enable_mixed_rec` is set, `mixed_loss`. When
        `return_predicted_img` is `True`, the tuple
        `(dictionary, like_dict["params"]["mean"])` is returned instead.
        Values are whatever `compute_batch_mean` returns (presumably one
        value per batch item -- TODO confirm against that helper).

    Raises
    ------
    NotImplementedError
        If `self._exclusion_loss_weight` is truthy.
    """
    if likelihood_obj is None:
        likelihood_obj = self.model.likelihood

    # Log likelihood
    ll, like_dict = likelihood_obj(reconstruction, target)
    ll = self._get_weighted_likelihood(ll)

    pad = self.skip_nboundary_pixels_from_loss
    if pad is not None and pad > 0:
        # Exclude a border of `pad` pixels along the last two axes, both
        # from the likelihood and from the predicted mean.
        ll = ll[:, :, pad:-pad, pad:-pad]
        like_dict["params"]["mean"] = like_dict["params"]["mean"][
            :, :, pad:-pad, pad:-pad
        ]

    output = {"loss": compute_batch_mean(-1 * ll)}
    if ll.shape[1] > 1:
        for i in range(1, 1 + target.shape[1]):
            output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1])
    else:
        assert ll.shape[1] == 1
        # Single-channel likelihood: both channel entries alias the total.
        output["ch1_loss"] = output["loss"]
        output["ch2_loss"] = output["loss"]

    if (
        self.channel_1_w is not None
        and self.channel_2_w is not None
        and (self.channel_1_w != 1 or self.channel_2_w != 1)
    ):
        assert ll.shape[1] == 2, "Only 2 channels are supported for now."
        # Weighted average of the two channel losses replaces the plain total.
        output["loss"] = (
            self.channel_1_w * output["ch1_loss"]
            + self.channel_2_w * output["ch2_loss"]
        ) / (self.channel_1_w + self.channel_2_w)

    # This `if` is not used by default config
    if self.enable_mixed_rec:
        mixed_pred, mixed_logvar = self.get_mixed_prediction(
            like_dict["params"]["mean"],
            like_dict["params"]["logvar"],
            self.data_mean,
            self.data_std,
        )
        if (
            self.model._multiscale_count is not None
            and self.model._multiscale_count > 1
        ):
            # With multiscale input, only the first input channel is
            # compared against the mixed prediction.
            assert input.shape[1] == self.model._multiscale_count
            input = input[:, :1]

        assert input.shape == mixed_pred.shape, (
            "Input and mixed prediction must have identical shapes, got "
            f"{tuple(input.shape)} and {tuple(mixed_pred.shape)}."
        )
        mixed_recons_ll = self.model.likelihood.log_likelihood(
            input, {"mean": mixed_pred, "logvar": mixed_logvar}
        )
        output["mixed_loss"] = compute_batch_mean(-1 * mixed_recons_ll)

    # This `if` is not used by default config
    if self._exclusion_loss_weight:
        raise NotImplementedError(
            "Exclusion loss is not well defined here, so it should not be used."
        )

    if return_predicted_img:
        return output, like_dict["params"]["mean"]

    return output
|
|
509
|
+
|
|
510
|
+
def reconstruction_loss_musplit_denoisplit(self, out, target_normalized):
    """
    Combine the two reconstruction terms into one weighted scalar.

    Parameters
    ----------
    out : torch.Tensor
        Model output. When `self.model.predict_logvar` is not `None` it is
        split into two halves along dim 1 and only the first half is given
        to `self.model.likelihood_NM`.
    target_normalized : torch.Tensor
        Target passed to both likelihoods.

    Returns
    -------
    torch.Tensor
        `self._denoisplit_w * nm_term + self._usplit_w * gm_term`, where each
        term is the negated mean of the first value returned by the
        respective likelihood.
    """
    if self.model.predict_logvar is None:
        mean_part = out
    else:
        # First half along dim 1 is used as the mean; second half is unused here.
        mean_part, _ = out.chunk(2, dim=1)

    nm_ll = self.model.likelihood_NM(mean_part, target_normalized)[0]
    nm_term = -1 * nm_ll.mean()

    gm_ll = self.model.likelihood_gm(out, target_normalized)[0]
    gm_term = -1 * gm_ll.mean()

    return self._denoisplit_w * nm_term + self._usplit_w * gm_term
|
|
524
|
+
|
|
525
|
+
def _get_weighted_likelihood(self, ll):
    """
    Scale each of the two channels of `ll` by its own weight.

    Channel 0 is multiplied by `self.ch1_recons_w` and channel 1 by
    `self.ch2_recons_w`. When both weights equal 1 the input is returned
    untouched.

    Parameters
    ----------
    ll : torch.Tensor
        Likelihood tensor with channels on dim 1. Any number of trailing
        dimensions is accepted.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as `ll` with per-channel weights applied.
    """
    if self.ch1_recons_w == 1 and self.ch2_recons_w == 1:
        return ll

    assert ll.shape[1] == 2, "This function is only for 2 channel images"

    # Weight each channel directly instead of building 4-D 0/1 masks: this
    # works for any number of trailing dimensions and does not multiply
    # values by zero (which would turn infinities into NaN).
    return torch.stack(
        [ll[:, 0] * self.ch1_recons_w, ll[:, 1] * self.ch2_recons_w], dim=1
    )
|
|
540
|
+
|
|
541
|
+
def get_kl_weight(self):
    """
    Return the weight applied to the KL loss for the current epoch.

    Without annealing (`self.kl_annealing` falsy) the result is
    `self.kl_weight` when it is set, otherwise 1.0.

    With annealing, the weight grows linearly with
    `self.current_epoch - self.kl_start`, reaching 1 after
    `self.kl_annealtime` epochs. It is clamped to [0, 1] and, when
    `self.kl_weight` is set, multiplied by it.

    Returns
    -------
    float
        The KL weight.
    """
    if not self.kl_annealing:
        # No annealing: use the fixed weight if given, else a neutral weight.
        return 1.0 if self.kl_weight is None else self.kl_weight

    # calculate relative weight
    kl_weight = (self.current_epoch - self.kl_start) * (1.0 / self.kl_annealtime)
    # clamp to [0,1]
    kl_weight = min(max(0.0, kl_weight), 1.0)

    # if the final weight is given, then apply that weight on top of it
    if self.kl_weight is not None:
        kl_weight = kl_weight * self.kl_weight
    return kl_weight
|
|
562
|
+
|
|
563
|
+
def get_kl_divergence_loss_usplit(
    self, topdown_layer_data_dict: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    Compute the KL loss with each layer normalized by its latent size.

    Every entry of `topdown_layer_data_dict["kl"]` is divided by the product
    of the last three dimensions of the matching entry in
    `topdown_layer_data_dict["z"]`. The entries are then placed side by side
    along dim 1, passed through `free_bits_kl` with a threshold of 0.0, and
    averaged.

    Parameters
    ----------
    topdown_layer_data_dict : Dict[str, torch.Tensor]
        Must provide the keys `"kl"` and `"z"`, each indexable per layer.

    Returns
    -------
    torch.Tensor
        Mean of the normalized KL values.
    """
    # NOTE: the assembled kl has shape (batch, layers), e.g. (16, 4).
    # Un-normalized values are sums and so are of the order 30000.
    layer_kls = topdown_layer_data_dict["kl"]
    layer_latents = topdown_layer_data_dict["z"]

    normalized_columns = []
    for layer_idx, layer_kl in enumerate(layer_kls):
        # e.g. latent shape[-3:] = 128 * 32 * 32
        latent_size = np.prod(layer_latents[layer_idx].shape[-3:])
        normalized_columns.append((layer_kl / latent_size).unsqueeze(1))

    kl = torch.cat(normalized_columns, dim=1)
    return free_bits_kl(kl, 0.0).mean()
|
|
586
|
+
|
|
587
|
+
def get_kl_divergence_loss(self, topdown_layer_data_dict, kl_key="kl"):
    """
    Compute the KL loss summed over batch and layers.

    Each entry of `topdown_layer_data_dict[kl_key]` has length batch_size;
    the entries are assembled into a tensor of shape (batch_size, layers).
    The result of `free_bits_kl(kl, self.free_bits)` is summed and divided
    by the product of `self.model.img_shape`.

    Parameters
    ----------
    topdown_layer_data_dict : dict
        Per-layer data; the entry under `kl_key` is iterated layer by layer.
    kl_key : str
        Key of the KL entry to use. Default is `"kl"`.

    Returns
    -------
    torch.Tensor
        The normalized KL loss.
    """
    columns = [layer_kl.unsqueeze(1) for layer_kl in topdown_layer_data_dict[kl_key]]
    kl = torch.cat(columns, dim=1)

    # As compared to uSplit kl divergence,
    # more by a factor of 4 just because we do sum and not mean.
    # NOTE: at each hierarchy, it is more by a factor of 128/i**2).
    # 128/(2*2) = 32 (bottommost layer)
    # 128/(4*4) = 8
    # 128/(8*8) = 2
    # 128/(16*16) = 0.5 (topmost layer)
    summed_kl = free_bits_kl(kl, self.free_bits).sum()

    # Normalize the summed KL by the number of elements in the image shape
    n_elements = np.prod(self.model.img_shape)
    return summed_kl / n_elements
|
|
609
|
+
|
|
610
|
+
##### UTILS Methods #####
|
|
611
|
+
def normalize_input(self, x):
    """
    Standardize `x` with the input statistics unless it is already normalized.

    When `self.model.normalized_input` is truthy, `x` is returned unchanged.
    Otherwise the mean of `self.data_mean["input"]` is subtracted and the
    result is divided by the mean of `self.data_std["input"]`.
    """
    if not self.model.normalized_input:
        input_mean = self.data_mean["input"].mean()
        input_std = self.data_std["input"].mean()
        return (x - input_mean) / input_std
    return x
|
|
615
|
+
|
|
616
|
+
def normalize_target(self, target, batch=None):
    """
    Standardize `target` with `self.data_mean["target"]` and `self.data_std["target"]`.

    The `batch` argument is accepted but not used.
    """
    centered = target - self.data_mean["target"]
    return centered / self.data_std["target"]
|
|
618
|
+
|
|
619
|
+
def unnormalize_target(self, target_normalized):
    """
    Undo target standardization.

    Multiplies by `self.data_std["target"]` and then adds
    `self.data_mean["target"]`.
    """
    rescaled = target_normalized * self.data_std["target"]
    return rescaled + self.data_mean["target"]
|
|
621
|
+
|
|
622
|
+
##### ADDITIONAL Methods #####
|
|
623
|
+
# def log_images_for_tensorboard(self, pred, target, img_mmse, label):
|
|
624
|
+
# clamped_pred = torch.clamp((pred - pred.min()) / (pred.max() - pred.min()), 0, 1)
|
|
625
|
+
# clamped_mmse = torch.clamp((img_mmse - img_mmse.min()) / (img_mmse.max() - img_mmse.min()), 0, 1)
|
|
626
|
+
# if target is not None:
|
|
627
|
+
# clamped_input = torch.clamp((target - target.min()) / (target.max() - target.min()), 0, 1)
|
|
628
|
+
# img = wandb.Image(clamped_input[None].cpu().numpy())
|
|
629
|
+
# self.logger.experiment.log({f'target_for{label}': img})
|
|
630
|
+
# # self.trainer.logger.experiment.add_image(f'target_for{label}', clamped_input[None], self.current_epoch)
|
|
631
|
+
# for i in range(3):
|
|
632
|
+
# # self.trainer.logger.experiment.add_image(f'{label}/sample_{i}', clamped_pred[i:i + 1], self.current_epoch)
|
|
633
|
+
# img = wandb.Image(clamped_pred[i:i + 1].cpu().numpy())
|
|
634
|
+
# self.logger.experiment.log({f'{label}/sample_{i}': img})
|
|
635
|
+
|
|
636
|
+
# img = wandb.Image(clamped_mmse[None].cpu().numpy())
|
|
637
|
+
# self.trainer.logger.experiment.log({f'{label}/mmse (100 samples)': img})
|
|
638
|
+
|
|
639
|
+
@property
def global_step(self) -> int:
    """Return the current value of the internal `_global_step` counter."""
    current_step = self._global_step
    return current_step
|
|
643
|
+
|
|
644
|
+
def increment_global_step(self):
    """Advance the internal `_global_step` counter by one."""
    self._global_step = self._global_step + 1
|
|
647
|
+
|
|
648
|
+
def set_params_to_same_device_as(self, correct_device_tensor: torch.Tensor):
    """
    Move the likelihood parameters and data statistics to a tensor's device.

    The call is forwarded to `self.model.likelihood`, then `self.data_mean`
    and `self.data_std` are moved to `correct_device_tensor.device`. Each
    statistic may be a tensor or a dict of tensors; dicts are updated in
    place. Mean and std are handled independently, so a std that lags behind
    an already-moved mean is still transferred.

    Parameters
    ----------
    correct_device_tensor : torch.Tensor
        Tensor whose device is the destination.
    """
    device = correct_device_tensor.device

    def _on_device(stats):
        # Tensor.to() hands back the same tensor when nothing has to change,
        # so no explicit device comparison is needed.
        if isinstance(stats, torch.Tensor):
            return stats.to(device)
        if isinstance(stats, dict):
            for key, value in stats.items():
                if isinstance(value, torch.Tensor):
                    stats[key] = value.to(device)
        return stats

    self.model.likelihood.set_params_to_same_device_as(correct_device_tensor)
    self.data_mean = _on_device(self.data_mean)
    self.data_std = _on_device(self.data_std)
|
|
660
|
+
|
|
661
|
+
def get_mixed_prediction(
    self, prediction, prediction_logvar, data_mean, data_std, channel_weights=None
):
    """
    Merge the per-channel predictions into one image in input-normalized space.

    The prediction is un-normalized with the target statistics, optionally
    weighted per channel, reduced over dim 1 (sum when `self._input_is_sum`
    is truthy, mean otherwise) and re-normalized with the mean of the input
    statistics.

    Parameters
    ----------
    prediction : torch.Tensor
        Predicted channels in target-normalized space, channels on dim 1.
    prediction_logvar : torch.Tensor or None
        Predicted log-variance, or `None` when no variance is predicted.
    data_mean, data_std : dict
        Statistics providing the keys `"target"` and `"input"`.
    channel_weights : torch.Tensor or number, optional
        Per-channel weights. `None` or the plain number 1 means unweighted.

    Returns
    -------
    tuple
        `(mixed_prediction, logvar)`. `logvar` is `None` when
        `prediction_logvar` is `None`. When the target and input stds are
        identical it is `prediction_logvar` unchanged. Otherwise it is the
        log of the summed, rescaled per-channel variances.
    """
    # Decide "unweighted" without comparing a tensor to 1: `==` / `!=` on a
    # multi-element tensor has no unambiguous truth value.
    unweighted = channel_weights is None or (
        isinstance(channel_weights, (int, float)) and channel_weights == 1
    )
    if unweighted:
        channel_weights = 1
    elif not isinstance(channel_weights, torch.Tensor):
        # torch.square below needs a tensor, not a plain number or sequence.
        channel_weights = torch.as_tensor(
            channel_weights, dtype=prediction.dtype, device=prediction.device
        )

    pred_unorm = prediction * data_std["target"] + data_mean["target"]
    weighted_pred = pred_unorm * channel_weights
    if self._input_is_sum:
        mixed_prediction = torch.sum(weighted_pred, dim=1, keepdim=True)
    else:
        mixed_prediction = torch.mean(weighted_pred, dim=1, keepdim=True)

    mixed_prediction = (mixed_prediction - data_mean["input"].mean()) / data_std[
        "input"
    ].mean()

    if prediction_logvar is None:
        return mixed_prediction, None

    if data_std["target"].shape == data_std["input"].shape and torch.all(
        data_std["target"] == data_std["input"]
    ):
        # Identical normalization: the predicted log-variance is passed on
        # unchanged; weighting is not supported on this path.
        assert unweighted, "channel_weights must be 1 when target and input std match"
        return mixed_prediction, prediction_logvar

    # Rescale the variance from target-normalized to input-normalized space.
    var = torch.exp(prediction_logvar)
    var = var * (data_std["target"] / data_std["input"]) ** 2
    if not unweighted:
        var = var * torch.square(channel_weights)

    # sum of variance.
    mixed_var = var.sum(dim=1, keepdim=True)
    return mixed_prediction, torch.log(mixed_var)
|