careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/losses/fcn/losses.py

@@ -0,0 +1,98 @@
+"""
+Loss submodule.
+
+This submodule contains the various losses used in CAREamics.
+"""
+
+import torch
+from torch.nn import L1Loss, MSELoss
+
+
+def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """
+    Mean squared error loss.
+
+    Parameters
+    ----------
+    source : torch.Tensor
+        Source patches.
+    target : torch.Tensor
+        Target patches.
+
+    Returns
+    -------
+    torch.Tensor
+        Loss value.
+    """
+    loss = MSELoss()
+    return loss(source, target)
+
+
+def n2v_loss(
+    manipulated_patches: torch.Tensor,
+    original_patches: torch.Tensor,
+    masks: torch.Tensor,
+) -> torch.Tensor:
+    """
+    N2V loss function described in A. Krull et al. (2018).
+
+    Parameters
+    ----------
+    manipulated_patches : torch.Tensor
+        Patches with manipulated pixels.
+    original_patches : torch.Tensor
+        Noisy patches.
+    masks : torch.Tensor
+        Array containing masked pixel locations.
+
+    Returns
+    -------
+    torch.Tensor
+        Loss value.
+    """
+    errors = (original_patches - manipulated_patches) ** 2
+    # Average over pixels and batch
+    loss = torch.sum(errors * masks) / torch.sum(masks)
+    return loss  # TODO change output to dict ?
+
+
+def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """
+    N2N loss function described in J. Lehtinen et al. (2018).
+
+    Parameters
+    ----------
+    samples : torch.Tensor
+        Raw patches.
+    labels : torch.Tensor
+        Different subset of noisy patches.
+
+    Returns
+    -------
+    torch.Tensor
+        Loss value.
+    """
+    loss = L1Loss()
+    return loss(samples, labels)
+
+
+# def pn2v_loss(
+#     samples: torch.Tensor,
+#     labels: torch.Tensor,
+#     masks: torch.Tensor,
+#     noise_model: HistogramNoiseModel,
+# ) -> torch.Tensor:
+#     """Probabilistic N2V loss function described in A Krull et al., CVF (2019)."""
+#     likelihoods = noise_model.likelihood(labels, samples)
+#     likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
+
+#     # Average over pixels and batch
+#     loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
+#     return loss
+
+
+# def dice_loss(
+#     samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
+# ) -> torch.Tensor:
+#     """Dice loss function."""
+#     return DiceLoss(mode=mode)(samples, labels.long())
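As a quick orientation, the sketch below shows how these losses could be called. The tensor shapes and the mask pattern are illustrative assumptions, not taken from the package; since `n2v_loss` normalizes by `torch.sum(masks)`, the mask must contain at least one nonzero entry.

```python
import torch

from careamics.losses.fcn.losses import mae_loss, mse_loss, n2v_loss

# Illustrative shapes: a batch of two single-channel 64x64 patches.
prediction = torch.rand(2, 1, 64, 64)
target = torch.rand(2, 1, 64, 64)

# Supervised losses compare prediction and target over all pixels.
print(mse_loss(prediction, target))  # scalar tensor
print(mae_loss(prediction, target))  # scalar tensor

# N2V is self-supervised: the squared error is accumulated only at the
# masked (manipulated) pixel locations, then normalized by the mask sum.
# Here `prediction` plays manipulated_patches and `target` the noisy originals.
masks = torch.zeros(2, 1, 64, 64)
masks[..., ::8, ::8] = 1.0  # hypothetical mask: every 8th pixel in x and y
print(n2v_loss(prediction, target, masks))
```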
careamics/losses/loss_factory.py

@@ -0,0 +1,155 @@
+"""
+Loss factory module.
+
+This module contains a factory function for creating loss functions.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+
+from torch import Tensor as tensor
+
+from ..config.support import SupportedLoss
+from .fcn.losses import mae_loss, mse_loss, n2v_loss
+from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+
+if TYPE_CHECKING:
+    from careamics.models.lvae.likelihoods import (
+        GaussianLikelihood,
+        NoiseModelLikelihood,
+    )
+    from careamics.models.lvae.noise_models import (
+        GaussianMixtureNoiseModel,
+        MultiChannelNoiseModel,
+    )
+
+    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+@dataclass
+class FCNLossParameters:
+    """Dataclass for FCN loss."""
+
+    # TODO check
+    prediction: tensor
+    targets: tensor
+    mask: tensor
+    current_epoch: int
+    loss_weight: float
+
+
+@dataclass  # TODO why not pydantic?
+class LVAELossParameters:
+    """Dataclass for LVAE loss."""
+
+    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
+    # e.g., - weights, - kl_params, ...
+
+    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
+    """Noise model likelihood instance."""
+    gaussian_likelihood: Optional[GaussianLikelihood] = None
+    """Gaussian likelihood instance."""
+    current_epoch: int = 0
+    """Current epoch in the training loop."""
+    reconstruction_weight: float = 1.0
+    """Weight for the reconstruction loss in the total net loss
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    musplit_weight: float = 0.0
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 1.0
+    """Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
+    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
+    """Type of KL divergence used as KL loss."""
+    kl_weight: float = 1.0
+    """Weight for the KL loss in the total net loss
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    kl_annealing: bool = False
+    """Whether to apply KL loss annealing."""
+    kl_start: int = -1
+    """Epoch at which KL loss annealing starts."""
+    kl_annealtime: int = 10
+    """Number of epochs for which KL loss annealing is applied."""
+    non_stochastic: bool = False
+    """Whether to sample latents and compute KL."""
+
+
+# TODO: really needed?
+# like it is now, it is difficult to use, we need a way to specify the
+# loss parameters in a more user-friendly way.
+def loss_parameters_factory(
+    type: SupportedLoss,
+) -> Union[FCNLossParameters, LVAELossParameters]:
+    """Return loss parameters.
+
+    Parameters
+    ----------
+    type : SupportedLoss
+        Requested loss.
+
+    Returns
+    -------
+    Union[FCNLossParameters, LVAELossParameters]
+        Loss parameters.
+
+    Raises
+    ------
+    NotImplementedError
+        If the loss is unknown.
+    """
+    if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
+        return FCNLossParameters
+
+    elif type in [
+        SupportedLoss.MUSPLIT,
+        SupportedLoss.DENOISPLIT,
+        SupportedLoss.DENOISPLIT_MUSPLIT,
+    ]:
+        return LVAELossParameters  # it returns the class, not an instance
+
+    else:
+        raise NotImplementedError(f"Loss {type} is not yet supported.")
+
+
+def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
+    """Return loss function.
+
+    Parameters
+    ----------
+    loss : Union[SupportedLoss, str]
+        Requested loss.
+
+    Returns
+    -------
+    Callable
+        Loss function.
+
+    Raises
+    ------
+    NotImplementedError
+        If the loss is unknown.
+    """
+    if loss == SupportedLoss.N2V:
+        return n2v_loss
+
+    # elif loss_type == SupportedLoss.PN2V:
+    #     return pn2v_loss
+
+    elif loss == SupportedLoss.MAE:
+        return mae_loss
+
+    elif loss == SupportedLoss.MSE:
+        return mse_loss
+
+    elif loss == SupportedLoss.MUSPLIT:
+        return musplit_loss
+
+    elif loss == SupportedLoss.DENOISPLIT:
+        return denoisplit_loss
+
+    elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
+        return denoisplit_musplit_loss
+
+    else:
+        raise NotImplementedError(f"Loss {loss} is not yet supported.")
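A minimal usage sketch of the factory. It assumes `SupportedLoss` is the string-backed enum suggested by the `Union[SupportedLoss, str]` signature (so a plain string like `"n2v"` would also match); the tensors are placeholders.

```python
import torch

from careamics.config.support import SupportedLoss
from careamics.losses.loss_factory import loss_factory

# Look up the loss callable; for N2V this returns n2v_loss itself.
loss_fn = loss_factory(SupportedLoss.N2V)

prediction = torch.rand(2, 1, 64, 64)     # network output on masked input
noisy_patches = torch.rand(2, 1, 64, 64)  # raw noisy patches
masks = torch.zeros(2, 1, 64, 64)
masks[..., ::8, ::8] = 1.0

loss = loss_fn(prediction, noisy_patches, masks)

# Note: loss_parameters_factory returns the parameter *class*, not an
# instance, as the inline comment in the source points out.
```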
careamics/losses/lvae/__init__.py

@@ -0,0 +1 @@
+"""LVAE losses."""
careamics/losses/lvae/loss_utils.py

@@ -0,0 +1,83 @@
+import torch
+
+
+def free_bits_kl(
+    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
+) -> torch.Tensor:
+    """Compute free-bits version of KL divergence.
+
+    This function ensures that the KL doesn't go to zero for any latent dimension.
+    Hence, it encourages the model to use latent variables more efficiently,
+    leading to better representation learning.
+
+    NOTE:
+    Takes in the KL with shape (batch size, layers), returns the KL with
+    free bits (for optimization) with shape (layers,), which is the average
+    free-bits KL per layer in the current batch.
+    If batch_average is False (default), the free bits are per layer and
+    per batch element. Otherwise, the free bits are still per layer, but
+    are assigned on average to the whole batch. In both cases, the batch
+    average is returned, so it's simply a matter of doing mean(clamp(KL))
+    or clamp(mean(KL)).
+
+    Parameters
+    ----------
+    kl : torch.Tensor
+        The KL divergence tensor with shape (batch size, layers).
+    free_bits : float
+        The free bits value. Set to 0.0 to disable free bits.
+    batch_average : bool
+        Whether to average over the batch before clamping to `free_bits`.
+    eps : float
+        A small value to avoid numerical instability.
+
+    Returns
+    -------
+    torch.Tensor
+        The free-bits version of the KL divergence with shape (layers,).
+    """
+    assert kl.dim() == 2
+    if free_bits < eps:
+        return kl.mean(0)
+    if batch_average:
+        return kl.mean(0).clamp(min=free_bits)
+    return kl.clamp(min=free_bits).mean(0)
+
+
+def get_kl_weight(
+    kl_annealing: bool,
+    kl_start: int,
+    kl_annealtime: int,
+    kl_weight: float,
+    current_epoch: int,
+) -> float:
+    """Compute the weight of the KL loss in case of annealing.
+
+    Parameters
+    ----------
+    kl_annealing : bool
+        Whether to use KL annealing.
+    kl_start : int
+        The epoch at which annealing starts.
+    kl_annealtime : int
+        The number of epochs for which annealing is applied.
+    kl_weight : float
+        The weight for the KL loss. If `None`, the weight is computed
+        using annealing, else it is set to a default of 1.
+    current_epoch : int
+        The current epoch.
+    """
+    if kl_annealing:
+        # calculate relative weight
+        kl_weight = (current_epoch - kl_start) * (1.0 / kl_annealtime)
+        # clamp to [0,1]
+        kl_weight = min(max(0.0, kl_weight), 1.0)
+
+        # if the final weight is given, then apply that weight on top of it
+        if kl_weight is not None:
+            kl_weight = kl_weight * kl_weight
+    elif kl_weight is not None:
+        return kl_weight
+    else:
+        kl_weight = 1.0
+    return kl_weight
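A small worked example of the clamping order in `free_bits_kl`; the numbers are made up to show where the two modes differ, and the non-annealing branch of `get_kl_weight` is included for completeness.

```python
import torch

from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight

# KL per batch element and hierarchy layer: shape (batch=2, layers=2).
kl = torch.tensor([[0.02, 1.50],
                   [0.90, 2.00]])

# Default: clamp each element to at least 0.5, then average over the batch.
# Layer 0: mean(max(0.02, 0.5), max(0.90, 0.5)) = mean(0.5, 0.9) = 0.70
print(free_bits_kl(kl, free_bits=0.5))                      # tensor([0.7000, 1.7500])

# batch_average=True: average over the batch first, then clamp the mean.
# Layer 0: max(mean(0.02, 0.90), 0.5) = max(0.46, 0.5) = 0.50
print(free_bits_kl(kl, free_bits=0.5, batch_average=True))  # tensor([0.5000, 1.7500])

# Without annealing, get_kl_weight simply returns the given weight.
print(get_kl_weight(
    kl_annealing=False, kl_start=-1, kl_annealtime=10,
    kl_weight=0.7, current_epoch=3,
))  # 0.7
```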