careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/losses/lvae/losses.py
CHANGED
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -13,20 +13,19 @@ from careamics.models.lvae.likelihoods import (
     LikelihoodModule,
     NoiseModelLikelihood,
 )
-from careamics.models.lvae.utils import compute_batch_mean
 
 if TYPE_CHECKING:
-    from careamics.
+    from careamics.config import LVAELossConfig
 
 Likelihood = Union[LikelihoodModule, GaussianLikelihood, NoiseModelLikelihood]
 
 
 def get_reconstruction_loss(
-    reconstruction: torch.Tensor,
+    reconstruction: torch.Tensor,
     target: torch.Tensor,
     likelihood_obj: Likelihood,
 ) -> dict[str, torch.Tensor]:
-    """Compute the reconstruction loss.
+    """Compute the reconstruction loss (negative log-likelihood).
 
     Parameters
     ----------
@@ -42,65 +41,15 @@ def get_reconstruction_loss(
 
     Returns
     -------
-
-
-        individual output channels `["ch{i}_loss"]`.
+    torch.Tensor
+        The recontruction loss (negative log-likelihood).
     """
-    loss_dict = _get_reconstruction_loss_vector(
-        reconstruction=reconstruction,
-        target=target,
-        likelihood_obj=likelihood_obj,
-    )
-
-    loss_dict["loss"] = loss_dict["loss"].sum() / len(reconstruction)
-    for i in range(1, 1 + target.shape[1]):
-        key = f"ch{i}_loss"
-        loss_dict[key] = loss_dict[key].sum() / len(reconstruction)
-
-    return loss_dict
-
-
-def _get_reconstruction_loss_vector(
-    reconstruction: torch.Tensor,  # TODO: naming -> predictions?
-    target: torch.Tensor,
-    likelihood_obj: LikelihoodModule,
-) -> dict[str, torch.Tensor]:
-    """Compute the reconstruction loss.
-
-    Parameters
-    ----------
-    return_predicted_img: bool
-        If set to `True`, the besides the loss, the reconstructed image is returned.
-        Default is `False`.
-
-    Returns
-    -------
-    dict[str, torch.Tensor]
-        A dictionary containing the overall loss `["loss"]` and the loss for
-        individual output channels `["ch{i}_loss"]`. Shape of individual
-        tensors is (B, ).
-    """
-    output = {"loss": None}
-    for i in range(1, 1 + target.shape[1]):
-        output[f"ch{i}_loss"] = None
-
     # Compute Log likelihood
     ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)
-
-    output = {"loss": compute_batch_mean(-1 * ll)}  # shape: (B, )
-    if ll.shape[1] > 1:  # target_ch > 1
-        for i in range(1, 1 + target.shape[1]):
-            output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1])  # shape: (B, )
-    else:  # target_ch == 1
-        # TODO: hacky!!! Refactor this
-        assert ll.shape[1] == 1
-        output["ch1_loss"] = output["loss"]
-        output["ch2_loss"] = output["loss"]
-
-    return output
+    return -1 * ll.mean()
 
 
-def reconstruction_loss_musplit_denoisplit(
+def _reconstruction_loss_musplit_denoisplit(
     predictions: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     targets: torch.Tensor,
     nm_likelihood: NoiseModelLikelihood,
@@ -137,62 +86,120 @@ def reconstruction_loss_musplit_denoisplit(
     recons_loss : torch.Tensor
         The reconstruction loss. Shape is (1, ).
     """
-    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
-    # (or viceversa)
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
-
+        pred_mean, _ = predictions.chunk(2, dim=1)
     else:
-
+        pred_mean = predictions
+
+    recons_loss_nm = get_reconstruction_loss(
+        reconstruction=pred_mean, target=targets, likelihood_obj=nm_likelihood
+    )
+
+    recons_loss_gm = get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=gaussian_likelihood,
+    )
 
-    recons_loss_nm = -1 * nm_likelihood(out_mean, targets)[0].mean()
-    recons_loss_gm = -1 * gaussian_likelihood(predictions, targets)[0].mean()
     recons_loss = nm_weight * recons_loss_nm + gaussian_weight * recons_loss_gm
     return recons_loss
 
 
-def
-
+def get_kl_divergence_loss(
+    kl_type: Literal["kl", "kl_restricted"],
+    topdown_data: dict[str, torch.Tensor],
+    rescaling: Literal["latent_dim", "image_dim"],
+    aggregation: Literal["mean", "sum"],
+    free_bits_coeff: float,
+    img_shape: Optional[tuple[int]] = None,
 ) -> torch.Tensor:
-    """Compute the KL divergence loss
+    """Compute the KL divergence loss.
+
+    NOTE: Description of `rescaling` methods:
+    - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space
+    dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way they
+    have the same magnitude across layers.
+    - If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial
+    dimensions. In this way, the lower layers have a larger KL-loss value compared to
+    the higher layers, since the latent space and hence the KL tensor has more entries.
+    Specifically, at hierarchy `i`, the total KL loss is larger by a factor (128/i**2).
+
+    NOTE: the type of `aggregation` determines the magnitude of the KL-loss. Clearly,
+    "sum" aggregation results in a larger KL-loss value compared to "mean" by a factor
+    of `n_layers`.
+
+    NOTE: recall that sample-wise KL is obtained by summing over all dimensions,
+    including Z. Also recall that in current 3D implementation of LVAE, no downsampling
+    is done on Z. Therefore, to avoid emphasizing KL loss too much, we divide it
+    by the Z dimension of input image in every case.
 
     Parameters
     ----------
-
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
+    topdown_data : dict[str, torch.Tensor]
         A dictionary containing information computed for each layer during the top-down
         pass. The dictionary must include the following keys:
         - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
        - "z": The sampled latents for each layer. Shape of each tensor is
        (B, layers, `z_dims[i]`, H, W).
-
-    The
-
-
+    rescaling : Literal["latent_dim", "image_dim"]
+        The rescaling method used for the KL-loss values. If "latent_dim", the KL-loss
+        values are rescaled w.r.t. the latent space dimensions (spatial + number of
+        channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss values are
+        rescaled w.r.t. the input image spatial dimensions.
+    aggregation : Literal["mean", "sum"]
+        The aggregation method used to combine the KL-loss values across layers. If
+        "mean", the KL-loss values are averaged across layers. If "sum", the KL-loss
+        values are summed across layers.
+    free_bits_coeff : float
+        The free bits coefficient used for the KL-loss computation.
+    img_shape : Optional[tuple[int]]
+        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss. Shape is (1, ).
     """
     kl = torch.cat(
-        [kl_layer.unsqueeze(1) for kl_layer in topdown_data[
+        [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
+        dim=1,
     )  # shape: (B, n_layers)
-    # NOTE: Values are sum() and so are of the order 30000
 
-
-
-    # NOTE: we want to normalize the KL-loss w.r.t. the latent space dimensions,
-    # i.e., the number of entries in the latent space tensors (C, [Z], Y, X).
-    # We assume z has shape (B, C, [Z], Y, X), where `C = z_dims[i]`.
-        norm_factor = np.prod(topdown_data["z"][i].shape[1:])
-        kl[:, i] = kl[:, i] / norm_factor
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
 
-
-    #
-
+    # In 3D case, rescale by Z dim
+    # TODO If we have downsampling in Z dimension, then this needs to change.
+    if len(img_shape) == 3:
+        kl = kl / img_shape[0]
 
+    # Rescaling
+    if rescaling == "latent_dim":
+        for i in range(len(kl)):
+            latent_dim = topdown_data["z"][i].shape[1:]
+            norm_factor = np.prod(latent_dim)
+            kl[i] = kl[i] / norm_factor
+    elif rescaling == "image_dim":
+        kl = kl / np.prod(img_shape[-2:])
 
-
+    # Aggregation
+    if aggregation == "mean":
+        kl = kl.mean()  # shape: (1,)
+    elif aggregation == "sum":
+        kl = kl.sum()  # shape: (1,)
+
+    return kl
+
+
+def _get_kl_divergence_loss_musplit(
     topdown_data: dict[str, torch.Tensor],
     img_shape: tuple[int],
-
+    kl_type: Literal["kl", "kl_restricted"],
 ) -> torch.Tensor:
-    """Compute the KL divergence loss for
+    """Compute the KL divergence loss for muSplit.
 
     Parameters
     ----------
@@ -204,32 +211,57 @@ def get_kl_divergence_loss_denoisplit(
        (B, layers, `z_dims[i]`, H, W).
    img_shape : tuple[int]
        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
-
-    The
-        To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
-        Default is "kl"
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
 
-
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss for the muSplit case. Shape is (1, ).
     """
-
-
-
+    return get_kl_divergence_loss(
+        kl_type="kl",  # TODO: hardcoded, deal in future PR
+        topdown_data=topdown_data,
+        rescaling="latent_dim",
+        aggregation="mean",
+        free_bits_coeff=0.0,
+        img_shape=img_shape,
    )
 
-    kl_loss = free_bits_kl(kl, 1.0).sum()
-    # NOTE: as compared to uSplit kl divergence, this KL loss is larger by a factor of
-    # `n_layers` since we sum KL contributions from different layers instead of taking
-    # the mean.
 
-
-
-
-
-
+def _get_kl_divergence_loss_denoisplit(
+    topdown_data: dict[str, torch.Tensor],
+    img_shape: tuple[int],
+    kl_type: Literal["kl", "kl_restricted"],
+) -> torch.Tensor:
+    """Compute the KL divergence loss for denoiSplit.
 
-
-
-
+    Parameters
+    ----------
+    topdown_data : dict[str, torch.Tensor]
+        A dictionary containing information computed for each layer during the top-down
+        pass. The dictionary must include the following keys:
+        - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+        - "z": The sampled latents for each layer. Shape of each tensor is
+        (B, layers, `z_dims[i]`, H, W).
+    img_shape : tuple[int]
+        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+    kl_type : Literal["kl", "kl_restricted"]
+        The type of KL divergence loss to compute.
+
+    Returns
+    -------
+    kl_loss : torch.Tensor
+        The KL divergence loss for the denoiSplit case. Shape is (1, ).
+    """
+    return get_kl_divergence_loss(
+        kl_type=kl_type,
+        topdown_data=topdown_data,
+        rescaling="image_dim",
+        aggregation="sum",
+        free_bits_coeff=1.0,
+        img_shape=img_shape,
+    )
 
 
 # TODO: @melisande-c suggested to refactor this as a class (see PR #208)
@@ -240,7 +272,9 @@ def get_kl_divergence_loss_denoisplit(
 def musplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
-
+    config: LVAELossConfig,
+    gaussian_likelihood: Optional[GaussianLikelihood],
+    noise_model_likelihood: Optional[NoiseModelLikelihood] = None,  # TODO: ugly
 ) -> Optional[dict[str, torch.Tensor]]:
     """Loss function for muSplit.
 
@@ -252,9 +286,13 @@ def musplit_loss(
     targets : torch.Tensor
         The target image used to compute the reconstruction loss. Shape is
         (B, `target_ch`, [Z], Y, X).
-
-    The
+    config : LVAELossConfig
+        The config for loss function (e.g., KL hyperparameters, likelihood module,
        noise model, etc.).
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : Optional[NoiseModelLikelihood]
+        The noise model likelihood object. Not used here.
 
     Returns
     -------
@@ -262,27 +300,35 @@ def musplit_loss(
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
     """
+    assert gaussian_likelihood is not None
+
     predictions, td_data = model_outputs
 
     # Reconstruction loss computation
-
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
         reconstruction=predictions,
         target=targets,
-        likelihood_obj=
+        likelihood_obj=gaussian_likelihood,
     )
-    recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
     if torch.isnan(recons_loss).any():
         recons_loss = 0.0
 
     # KL loss computation
     kl_weight = get_kl_weight(
-
-
-
-
-
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_musplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
     )
-    kl_loss = kl_weight * get_kl_divergence_loss_usplit(td_data)
 
     net_loss = recons_loss + kl_loss
     output = {
@@ -304,7 +350,9 @@ def musplit_loss(
 def denoisplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
-
+    config: LVAELossConfig,
+    gaussian_likelihood: Optional[GaussianLikelihood] = None,
+    noise_model_likelihood: Optional[NoiseModelLikelihood] = None,
 ) -> Optional[dict[str, torch.Tensor]]:
     """Loss function for DenoiSplit.
 
@@ -316,9 +364,12 @@ def denoisplit_loss(
     targets : torch.Tensor
         The target image used to compute the reconstruction loss. Shape is
         (B, `target_ch`, [Z], Y, X).
-
-    The
-
+    config : LVAELossConfig
+        The config for loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
 
     Returns
     -------
@@ -326,33 +377,35 @@ def denoisplit_loss(
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
     """
+    assert noise_model_likelihood is not None
+
     predictions, td_data = model_outputs
 
     # Reconstruction loss computation
-
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
         reconstruction=predictions,
         target=targets,
-        likelihood_obj=
+        likelihood_obj=noise_model_likelihood,
     )
-    recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
     if torch.isnan(recons_loss).any():
         recons_loss = 0.0
 
     # KL loss computation
-
-
-
-
-
-
-
-
-
-    )
-    kl_loss = kl_weight * get_kl_divergence_loss_denoisplit(
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
            topdown_data=td_data,
-        img_shape=targets.shape[2:],
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
        )
+        * kl_weight
+    )
 
     net_loss = recons_loss + kl_loss
     output = {
@@ -374,7 +427,9 @@ def denoisplit_loss(
 def denoisplit_musplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
-
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood,
+    noise_model_likelihood: NoiseModelLikelihood,
 ) -> Optional[dict[str, torch.Tensor]]:
     """Loss function for DenoiSplit.
 
@@ -386,9 +441,12 @@ def denoisplit_musplit_loss(
     targets : torch.Tensor
         The target image used to compute the reconstruction loss. Shape is
         (B, `target_ch`, [Z], Y, X).
-
-    The
-
+    config : LVAELossConfig
+        The config for loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
 
     Returns
     -------
@@ -399,34 +457,35 @@ def denoisplit_musplit_loss(
     predictions, td_data = model_outputs
 
     # Reconstruction loss computation
-    recons_loss =
+    recons_loss = _reconstruction_loss_musplit_denoisplit(
         predictions=predictions,
         targets=targets,
-        nm_likelihood=
-        gaussian_likelihood=
-        nm_weight=
-        gaussian_weight=
+        nm_likelihood=noise_model_likelihood,
+        gaussian_likelihood=gaussian_likelihood,
+        nm_weight=config.denoisplit_weight,
+        gaussian_weight=config.musplit_weight,
     )
     if torch.isnan(recons_loss).any():
         recons_loss = 0.0
 
     # KL loss computation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
+    # The different naming comes from `top_down_pass()` method in the LadderVAE.
+    denoisplit_kl = _get_kl_divergence_loss_denoisplit(
+        topdown_data=td_data,
+        img_shape=targets.shape[2:],
+        kl_type=config.kl_params.loss_type,
+    )
+    musplit_kl = _get_kl_divergence_loss_musplit(
+        topdown_data=td_data,
+        img_shape=targets.shape[2:],
+        kl_type=config.kl_params.loss_type,
+    )
+    kl_loss = (
+        config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl
+    )
+    # TODO `kl_weight` is hardcoded (???)
+    kl_loss = config.kl_weight * kl_loss
 
     net_loss = recons_loss + kl_loss
     output = {