careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/losses/lvae/losses.py
@@ -0,0 +1,589 @@
+"""Methods for Loss Computation."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, Union
+
+import numpy as np
+import torch
+
+from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight
+from careamics.models.lvae.likelihoods import (
+    GaussianLikelihood,
+    LikelihoodModule,
+    NoiseModelLikelihood,
+)
+
+if TYPE_CHECKING:
+    from careamics.config import LVAELossConfig
+
+Likelihood = Union[LikelihoodModule, GaussianLikelihood, NoiseModelLikelihood]
+
+
+def get_reconstruction_loss(
+    reconstruction: torch.Tensor,
+    target: torch.Tensor,
+    likelihood_obj: Likelihood,
+) -> torch.Tensor:
+    """Compute the reconstruction loss (negative log-likelihood).
+
+    Parameters
+    ----------
+    reconstruction : torch.Tensor
+        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the
+        number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+    target : torch.Tensor
+        The target image used to compute the reconstruction loss. Shape is
+        (B, C, [Z], Y, X), where C is the number of output channels
+        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+    likelihood_obj : Likelihood
+        The likelihood object used to compute the reconstruction loss.
+
+    Returns
+    -------
+    torch.Tensor
+        The reconstruction loss (negative log-likelihood).
+    """
+    # Compute the log-likelihood
+    ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)
+    return -1 * ll.mean()
+
+
+def _reconstruction_loss_musplit_denoisplit(
+    predictions: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+    targets: torch.Tensor,
+    nm_likelihood: NoiseModelLikelihood,
+    gaussian_likelihood: GaussianLikelihood,
+    nm_weight: float,
+    gaussian_weight: float,
+) -> torch.Tensor:
+    """Compute the reconstruction loss for the combined muSplit-denoiSplit loss.
+
+    The resulting loss is a weighted mean of the noise model likelihood and the
+    Gaussian likelihood.
+
+    Parameters
+    ----------
+    predictions : torch.Tensor
+        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), or
+        (B, 2*C, [Z], Y, X), where C is the number of output channels,
+        and the factor of 2 is for the case of predicted log-variance.
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. Shape is
+        (B, C, [Z], Y, X), where C is the number of output channels
+        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+    nm_likelihood : NoiseModelLikelihood
+        A `NoiseModelLikelihood` object used to compute the noise model likelihood.
+    gaussian_likelihood : GaussianLikelihood
+        A `GaussianLikelihood` object used to compute the Gaussian likelihood.
+    nm_weight : float
+        The weight for the noise model likelihood.
+    gaussian_weight : float
+        The weight for the Gaussian likelihood.
+
+    Returns
+    -------
+    recons_loss : torch.Tensor
+        The reconstruction loss. Shape is (1, ).
+    """
+    if predictions.shape[1] == 2 * targets.shape[1]:
+        # predictions contain both mean and log-variance
+        pred_mean, _ = predictions.chunk(2, dim=1)
+        # TODO if this condition does not hold, everything breaks later!
+    else:
+        pred_mean = predictions
+
+    recons_loss_nm = get_reconstruction_loss(
+        reconstruction=pred_mean, target=targets, likelihood_obj=nm_likelihood
+    )
+
+    recons_loss_gm = get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=gaussian_likelihood,
+    )
+
+    recons_loss = nm_weight * recons_loss_nm + gaussian_weight * recons_loss_gm
+    return recons_loss
+
+
+def get_kl_divergence_loss(
+    kl_type: Literal["kl", "kl_restricted"],
+    topdown_data: dict[str, torch.Tensor],
+    rescaling: Literal["latent_dim", "image_dim"],
+    aggregation: Literal["mean", "sum"],
+    free_bits_coeff: float,
+    img_shape: tuple[int] | None = None,
+) -> torch.Tensor:
+    """Compute the KL divergence loss.
+
+    NOTE: Description of the `rescaling` methods:
+    - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space
+    dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way they
+    have the same magnitude across layers.
+    - If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial
+    dimensions. In this way, the lower layers have a larger KL-loss value compared to
+    the higher layers, since the latent space and hence the KL tensor has more entries.
+    Specifically, at hierarchy `i`, the total KL loss is larger by a factor (128/i**2).
+
+    NOTE: the type of `aggregation` determines the magnitude of the KL-loss. Clearly,
+    "sum" aggregation results in a larger KL-loss value compared to "mean" by a factor
+    of `n_layers`.
+
+    NOTE: recall that the sample-wise KL is obtained by summing over all dimensions,
+    including Z. Also recall that in the current 3D implementation of the LVAE, no
+    downsampling is done on Z. Therefore, to avoid emphasizing the KL loss too much,
+    we divide it by the Z dimension of the input image in every case.
+
+    Parameters
+    ----------
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
+    topdown_data : dict[str, torch.Tensor]
+        A dictionary containing information computed for each layer during the
+        top-down pass. The dictionary must include the following keys:
+        - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+        - "z": The sampled latents for each layer. Shape of each tensor is
+        (B, `z_dims[i]`, H, W).
+    rescaling : Literal["latent_dim", "image_dim"]
+        The rescaling method used for the KL-loss values. If "latent_dim", the KL-loss
+        values are rescaled w.r.t. the latent space dimensions (spatial + number of
+        channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss values are
+        rescaled w.r.t. the input image spatial dimensions.
+    aggregation : Literal["mean", "sum"]
+        The aggregation method used to combine the KL-loss values across layers. If
+        "mean", the KL-loss values are averaged across layers. If "sum", the KL-loss
+        values are summed across layers.
+    free_bits_coeff : float
+        The free bits coefficient used for the KL-loss computation.
+    img_shape : Optional[tuple[int]]
+        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss. Shape is (1, ).
+    """
+    kl = torch.cat(
+        [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
+        dim=1,
+    )  # shape: (B, n_layers)
+
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
+    # In the 3D case, rescale by the Z dim
+    # TODO If we have downsampling in the Z dimension, then this needs to change.
+    if img_shape is not None and len(img_shape) == 3:
+        kl = kl / img_shape[0]
+
+    # Rescaling
+    if rescaling == "latent_dim":
+        for i in range(len(kl)):
+            latent_dim = topdown_data["z"][i].shape[1:]
+            norm_factor = np.prod(latent_dim)
+            kl[i] = kl[i] / norm_factor
+    elif rescaling == "image_dim":
+        kl = kl / np.prod(img_shape[-2:])
+
+    # Aggregation
+    if aggregation == "mean":
+        kl = kl.mean()  # shape: (1,)
+    elif aggregation == "sum":
+        kl = kl.sum()  # shape: (1,)
+
+    return kl
+
+
+def _get_kl_divergence_loss_musplit(
+    topdown_data: dict[str, torch.Tensor],
+    img_shape: tuple[int],
+    kl_type: Literal["kl", "kl_restricted"],
+) -> torch.Tensor:
+    """Compute the KL divergence loss for muSplit.
+
+    Parameters
+    ----------
+    topdown_data : dict[str, torch.Tensor]
+        A dictionary containing information computed for each layer during the
+        top-down pass. The dictionary must include the following keys:
+        - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+        - "z": The sampled latents for each layer. Shape of each tensor is
+        (B, `z_dims[i]`, H, W).
+    img_shape : tuple[int]
+        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
+
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss for the muSplit case. Shape is (1, ).
+    """
+    return get_kl_divergence_loss(
+        kl_type="kl",  # TODO: hardcoded, deal with in a future PR
+        topdown_data=topdown_data,
+        rescaling="latent_dim",
+        aggregation="mean",
+        free_bits_coeff=0.0,
+        img_shape=img_shape,
+    )
+
+
+def _get_kl_divergence_loss_denoisplit(
+    topdown_data: dict[str, torch.Tensor],
+    img_shape: tuple[int],
+    kl_type: Literal["kl", "kl_restricted"],
+) -> torch.Tensor:
+    """Compute the KL divergence loss for denoiSplit.
+
+    Parameters
+    ----------
+    topdown_data : dict[str, torch.Tensor]
+        A dictionary containing information computed for each layer during the
+        top-down pass. The dictionary must include the following keys:
+        - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+        - "z": The sampled latents for each layer. Shape of each tensor is
+        (B, `z_dims[i]`, H, W).
+    img_shape : tuple[int]
+        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
+
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss for the denoiSplit case. Shape is (1, ).
+    """
+    return get_kl_divergence_loss(
+        kl_type=kl_type,
+        topdown_data=topdown_data,
+        rescaling="image_dim",
+        aggregation="sum",
+        free_bits_coeff=1.0,
+        img_shape=img_shape,
+    )
+
+
+# TODO: @melisande-c suggested to refactor this as a class (see PR #208)
+# - loss computation happens by calling the `__call__` method
+# - `__init__` method initializes the loss parameters now contained in
+#   the `LVAELossParameters` class
+# NOTE: same for the other loss functions
+
+
+def hdn_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for HDN.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. In this case we use
+        the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for the loss function containing all loss hyperparameters.
+    gaussian_likelihood : Optional[GaussianLikelihood]
+        The Gaussian likelihood object.
+    noise_model_likelihood : Optional[NoiseModelLikelihood]
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    if gaussian_likelihood is not None:
+        likelihood = gaussian_likelihood
+    elif noise_model_likelihood is not None:
+        likelihood = noise_model_likelihood
+    else:
+        raise ValueError("Invalid likelihood object.")
+    # TODO refactor loss signature
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss  # TODO add check that loss coefficients sum to 1
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
+def musplit_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,  # TODO: ugly
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for muSplit.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. Shape is
+        (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for the loss function (e.g., KL hyperparameters, likelihood module,
+        noise model, etc.).
+    gaussian_likelihood : Optional[GaussianLikelihood]
+        The Gaussian likelihood object.
+    noise_model_likelihood : Optional[NoiseModelLikelihood]
+        The noise model likelihood object. Not used here.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    assert gaussian_likelihood is not None
+
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=gaussian_likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_musplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
+def denoisplit_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None = None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for denoiSplit.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. Shape is
+        (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for the loss function containing all loss hyperparameters.
+    gaussian_likelihood : Optional[GaussianLikelihood]
+        The Gaussian likelihood object. Not used here.
+    noise_model_likelihood : Optional[NoiseModelLikelihood]
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    assert noise_model_likelihood is not None
+
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=noise_model_likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
+def denoisplit_musplit_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood,
+    noise_model_likelihood: NoiseModelLikelihood,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for the combined denoiSplit-muSplit case.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. Shape is
+        (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for the loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = _reconstruction_loss_musplit_denoisplit(
+        predictions=predictions,
+        targets=targets,
+        nm_likelihood=noise_model_likelihood,
+        gaussian_likelihood=gaussian_likelihood,
+        nm_weight=config.denoisplit_weight,
+        gaussian_weight=config.musplit_weight,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    # NOTE: the 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
+    # The different naming comes from the `top_down_pass()` method in the LadderVAE.
+    denoisplit_kl = _get_kl_divergence_loss_denoisplit(
+        topdown_data=td_data,
+        img_shape=targets.shape[2:],
+        kl_type=config.kl_params.loss_type,
+    )
+    musplit_kl = _get_kl_divergence_loss_musplit(
+        topdown_data=td_data,
+        img_shape=targets.shape[2:],
+        kl_type=config.kl_params.loss_type,
+    )
+    kl_loss = (
+        config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl
+    )
+    # TODO `kl_weight` is hardcoded (???)
+    kl_loss = config.kl_weight * kl_loss
+
+    net_loss = recons_loss + kl_loss
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
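For reference, a few minimal usage sketches of the helpers added in `careamics/losses/lvae/losses.py`, assuming careamics 0.0.19 is installed. First, `get_reconstruction_loss` only requires its `likelihood_obj` to be callable as `ll, _ = likelihood_obj(reconstruction, target)` and to return per-pixel log-likelihoods; `ToyGaussianLikelihood` below is a hypothetical stand-in for the careamics likelihood modules, not part of the package:

```python
import torch

from careamics.losses.lvae.losses import get_reconstruction_loss


class ToyGaussianLikelihood:
    """Hypothetical stand-in following the call convention in the diff above."""

    def __call__(self, reconstruction: torch.Tensor, target: torch.Tensor):
        # Unnormalized unit-variance Gaussian log-likelihood, plus unused extra data.
        return -0.5 * (reconstruction - target) ** 2, None


pred = torch.randn(4, 1, 64, 64)  # (B, C, Y, X)
target = torch.randn(4, 1, 64, 64)
loss = get_reconstruction_loss(pred, target, ToyGaussianLikelihood())
print(loss)  # scalar tensor: mean negative log-likelihood
```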
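`get_kl_divergence_loss` can be exercised with synthetic top-down data following the shapes documented in its docstring (per-layer KL of shape (B,), per-layer latents of shape (B, C, Y, X)); with a 2D `img_shape` the Z-rescaling step is skipped. A sketch using the muSplit-style settings of `_get_kl_divergence_loss_musplit`:

```python
import torch

from careamics.losses.lvae.losses import get_kl_divergence_loss

batch, n_layers = 4, 3
topdown_data = {
    # Per-layer sample-wise KL, each of shape (B,).
    "kl": [torch.rand(batch) for _ in range(n_layers)],
    # Per-layer sampled latents, each of shape (B, C, Y, X).
    "z": [torch.randn(batch, 32, 16 * 2**i, 16 * 2**i) for i in range(n_layers)],
}

kl_loss = get_kl_divergence_loss(
    kl_type="kl",
    topdown_data=topdown_data,
    rescaling="latent_dim",  # same magnitude across layers
    aggregation="mean",
    free_bits_coeff=0.0,
    img_shape=(64, 64),  # (Y, X): 2D input, so no Z rescaling
)
print(kl_loss)  # scalar tensor
```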
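The channel convention handled by `_reconstruction_loss_musplit_denoisplit` is worth spelling out: when the decoder also predicts a log-variance, its output carries `2*C` channels and only the first half (the mean) is passed to the noise model likelihood. A self-contained illustration of the chunking:

```python
import torch

target_ch = 3
predictions = torch.randn(2, 2 * target_ch, 64, 64)  # (B, 2*C, Y, X): [mean | log-var]
targets = torch.randn(2, target_ch, 64, 64)

if predictions.shape[1] == 2 * targets.shape[1]:
    # Split channel-wise into predicted mean and predicted log-variance.
    pred_mean, pred_logvar = predictions.chunk(2, dim=1)

print(pred_mean.shape, pred_logvar.shape)  # both: torch.Size([2, 3, 64, 64])
```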
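Finally, `denoisplit_musplit_loss` can be called end to end with stand-ins. Note that `config` below is a `SimpleNamespace` exposing only the attributes this function reads, not a real `LVAELossConfig`, and the toy likelihood again replaces the careamics likelihood modules:

```python
from types import SimpleNamespace

import torch

from careamics.losses.lvae.losses import denoisplit_musplit_loss


class ToyLikelihood:
    """Hypothetical likelihood returning (log-likelihood, extra data)."""

    def __call__(self, reconstruction, target):
        return -0.5 * (reconstruction - target) ** 2, None


batch, target_ch, n_layers = 4, 2, 3
predictions = torch.randn(batch, target_ch, 64, 64)  # no predicted log-variance here
targets = torch.randn(batch, target_ch, 64, 64)
td_data = {
    "kl": [torch.rand(batch) for _ in range(n_layers)],
    "z": [torch.randn(batch, 32, 16, 16) for _ in range(n_layers)],
}
config = SimpleNamespace(  # stand-in for LVAELossConfig
    denoisplit_weight=0.9,
    musplit_weight=0.1,
    kl_weight=1.0,
    kl_params=SimpleNamespace(loss_type="kl"),
)

out = denoisplit_musplit_loss(
    model_outputs=(predictions, td_data),
    targets=targets,
    config=config,
    gaussian_likelihood=ToyLikelihood(),
    noise_model_likelihood=ToyLikelihood(),
)
print(out["loss"], out["reconstruction_loss"], out["kl_loss"])
```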