careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the package's registry page for details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """Compute free-bits version of KL divergence.

    This function ensures that the KL doesn't go to zero for any latent dimension.
    Hence, it contributes to use latent variables more efficiently, leading to
    better representation learning.

    NOTE:
    Takes in the KL with shape (batch size, layers), returns the KL with
    free bits (for optimization) with shape (layers,), which is the average
    free-bits KL per layer in the current batch.
    If batch_average is False (default), the free bits are per layer and
    per batch element. Otherwise, the free bits are still per layer, but
    are assigned on average to the whole batch. In both cases, the batch
    average is returned, so it's simply a matter of doing mean(clamp(KL))
    or clamp(mean(KL)).

    Parameters
    ----------
    kl : torch.Tensor
        The KL divergence tensor with shape (batch size, layers).
    free_bits : float
        The free bits value, i.e. the lower bound the KL is clamped to. Values
        smaller than `eps` (e.g. 0.0) disable free bits.
    batch_average : bool
        Whether to average over the batch before clamping to `free_bits`.
    eps : float
        Threshold below which `free_bits` is treated as zero: free bits are then
        disabled and the plain batch-averaged KL is returned.

    Returns
    -------
    torch.Tensor
        The free-bits version of the KL divergence with shape (layers,).

    Raises
    ------
    ValueError
        If `kl` is not a 2D tensor.
    """
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if kl.dim() != 2:
        raise ValueError(
            f"Expected `kl` with shape (batch size, layers), got {tuple(kl.shape)}."
        )
    if free_bits < eps:
        # Free bits disabled: plain average over the batch.
        return kl.mean(0)
    if batch_average:
        # Clamp the batch-averaged KL of each layer.
        return kl.mean(0).clamp(min=free_bits)
    # Clamp every (sample, layer) entry, then average over the batch.
    return kl.clamp(min=free_bits).mean(0)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_kl_weight(
    kl_annealing: bool,
    kl_start: int,
    kl_annealtime: int,
    kl_weight: float,
    current_epoch: int,
) -> float:
    """Compute the weight of the KL loss, optionally with linear annealing.

    With annealing enabled, an annealing factor grows linearly from 0 at epoch
    `kl_start` to 1 at epoch `kl_start + kl_annealtime` and is clamped to [0, 1].
    The returned weight is that factor multiplied by the final weight `kl_weight`
    (or by 1 if `kl_weight` is `None`).

    Parameters
    ----------
    kl_annealing : bool
        Whether to use KL annealing.
    kl_start : int
        The epoch at which annealing starts.
    kl_annealtime : int
        The number of epochs over which annealing is applied. A non-positive
        value means the weight jumps from 0 to its final value at `kl_start`.
    kl_weight : float
        The (final) weight for the KL loss. If `None`, it defaults to 1.
    current_epoch : int
        The current epoch.

    Returns
    -------
    float
        The KL loss weight for the current epoch.
    """
    if kl_annealing:
        if kl_annealtime > 0:
            annealing_factor = (current_epoch - kl_start) / kl_annealtime
        else:
            # No annealing period: step from 0 to 1 at `kl_start`
            # (avoids a division by zero).
            annealing_factor = 1.0 if current_epoch >= kl_start else 0.0
        # clamp to [0, 1]
        annealing_factor = min(max(0.0, annealing_factor), 1.0)

        # Apply the final weight (if given) on top of the annealing factor. A
        # separate local keeps the caller's `kl_weight` from being overwritten.
        final_weight = 1.0 if kl_weight is None else kl_weight
        return annealing_factor * final_weight
    if kl_weight is not None:
        return kl_weight
    return 1.0
|
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
"""Methods for Loss Computation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight
|
|
11
|
+
from careamics.models.lvae.likelihoods import (
|
|
12
|
+
GaussianLikelihood,
|
|
13
|
+
LikelihoodModule,
|
|
14
|
+
NoiseModelLikelihood,
|
|
15
|
+
)
|
|
16
|
+
from careamics.models.lvae.utils import compute_batch_mean
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from careamics.losses.loss_factory import LVAELossParameters
|
|
20
|
+
|
|
21
|
+
# Type alias for any likelihood object accepted by the reconstruction-loss helpers.
Likelihood = Union[
    LikelihoodModule,
    GaussianLikelihood,
    NoiseModelLikelihood,
]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_reconstruction_loss(
    reconstruction: torch.Tensor,  # TODO: naming -> predictions?
    target: torch.Tensor,
    likelihood_obj: Likelihood,
) -> dict[str, torch.Tensor]:
    """Compute the reconstruction loss.

    Parameters
    ----------
    reconstruction: torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the
        number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    target: torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, C, [Z], Y, X), where C is the number of output channels
        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    likelihood_obj: Likelihood
        The likelihood object used to compute the reconstruction loss.

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the overall loss `["loss"]` and the loss for
        individual output channels `["ch{i}_loss"]`, each averaged over the batch.
    """
    loss_dict = _get_reconstruction_loss_vector(
        reconstruction=reconstruction,
        target=target,
        likelihood_obj=likelihood_obj,
    )

    # Reduce every per-sample loss (shape (B,)) to a batch average. Iterating over
    # all returned keys also reduces the extra `ch2_loss` entry that the helper
    # adds for single-channel targets, so no unreduced (B,) tensor is returned.
    batch_size = len(reconstruction)
    return {key: value.sum() / batch_size for key, value in loss_dict.items()}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _get_reconstruction_loss_vector(
    reconstruction: torch.Tensor,  # TODO: naming -> predictions?
    target: torch.Tensor,
    likelihood_obj: LikelihoodModule,
) -> dict[str, torch.Tensor]:
    """Compute the per-sample reconstruction loss.

    Parameters
    ----------
    reconstruction : torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X).
    target : torch.Tensor
        The target image. Shape is (B, C, [Z], Y, X).
    likelihood_obj : LikelihoodModule
        The likelihood object. Calling it with `(reconstruction, target)` must
        return a tuple whose first element is the log-likelihood.

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the overall loss `["loss"]` and the loss for
        individual output channels `["ch{i}_loss"]`. Shape of individual
        tensors is (B, ). When the log-likelihood has a single channel, both
        `["ch1_loss"]` and `["ch2_loss"]` are set to the overall loss.
    """
    # Compute log-likelihood
    ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)

    output = {"loss": compute_batch_mean(-1 * ll)}  # shape: (B, )
    if ll.shape[1] > 1:  # target_ch > 1
        for i in range(1, 1 + target.shape[1]):
            output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1])  # shape: (B, )
    else:  # target_ch == 1
        # TODO: hacky!!! Refactor this
        assert ll.shape[1] == 1
        output["ch1_loss"] = output["loss"]
        output["ch2_loss"] = output["loss"]

    return output
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def reconstruction_loss_musplit_denoisplit(
    predictions: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    targets: torch.Tensor,
    nm_likelihood: NoiseModelLikelihood,
    gaussian_likelihood: GaussianLikelihood,
    nm_weight: float,
    gaussian_weight: float,
) -> torch.Tensor:
    """Compute the combined muSplit-denoiSplit reconstruction loss.

    The result is a weighted sum of the negative noise-model log-likelihood and
    the negative Gaussian log-likelihood, each averaged over all elements.

    Parameters
    ----------
    predictions : torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), or
        (B, 2*C, [Z], Y, X) when the log-variance is predicted as well, where C
        is the number of output channels.
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, C, [Z], Y, X).
    nm_likelihood : NoiseModelLikelihood
        Object used to compute the noise model log-likelihood; it receives only
        the predicted mean.
    gaussian_likelihood : GaussianLikelihood
        Object used to compute the Gaussian log-likelihood; it receives the full
        predictions.
    nm_weight : float
        The weight of the noise model term.
    gaussian_weight : float
        The weight of the Gaussian term.

    Returns
    -------
    recons_loss : torch.Tensor
        The reconstruction loss (a single value).
    """
    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
    # (or viceversa)
    has_logvar = predictions.shape[1] == 2 * targets.shape[1]
    # The mean occupies the first half of the channels when the log-variance is
    # also predicted.
    mean_pred = predictions.chunk(2, dim=1)[0] if has_logvar else predictions

    nm_ll = nm_likelihood(mean_pred, targets)[0]
    gaussian_ll = gaussian_likelihood(predictions, targets)[0]

    nm_term = nm_weight * -nm_ll.mean()
    gaussian_term = gaussian_weight * -gaussian_ll.mean()
    return nm_term + gaussian_term
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_kl_divergence_loss_usplit(
    topdown_data: dict[str, list[torch.Tensor]], kl_key: str = "kl"
) -> torch.Tensor:
    """Compute the KL divergence loss for muSplit.

    Each layer's KL is divided by the number of entries of that layer's sampled
    latent (all dimensions except the batch one). The result is averaged over the
    batch and then over the layers. Free bits are disabled.

    Parameters
    ----------
    topdown_data : dict[str, list[torch.Tensor]]
        A dictionary containing information computed for each layer during the
        top-down pass. It must include the following keys:
        - `kl_key` (default "kl"): the KL-loss values for each layer, one tensor
          of shape (B,) per layer.
        - "z": the sampled latents for each layer. The normalization assumes
          each tensor has shape (B, C, [Z], Y, X) with `C = z_dims[i]`.
    kl_key : str
        The key for the KL-loss values in the top-down layer data dictionary.
        To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
        Default is "kl".

    Returns
    -------
    torch.Tensor
        The KL loss as a single (0-dimensional) value.
    """
    columns = [layer_kl.unsqueeze(1) for layer_kl in topdown_data[kl_key]]
    kl = torch.cat(columns, dim=1)  # shape: (B, n_layers)
    # NOTE: Values are sum() and so are of the order 30000

    latents = topdown_data["z"]
    for layer_idx in range(kl.shape[1]):
        # Normalize w.r.t. the size of this layer's latent space (C, [Z], Y, X).
        n_latent_entries = np.prod(latents[layer_idx].shape[1:])
        kl[:, layer_idx] = kl[:, layer_idx] / n_latent_entries

    # free_bits=0.0 disables free bits, i.e. this is the per-layer batch mean.
    return free_bits_kl(kl, 0.0).mean()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def get_kl_divergence_loss_denoisplit(
    topdown_data: dict[str, torch.Tensor],
    img_shape: tuple[int],
    kl_key: str = "kl",
) -> torch.Tensor:
    """Compute the KL divergence loss for denoiSplit.

    The per-layer KL values are stacked into a (B, n_layers) tensor. The tensor is
    clamped with free bits of 1.0 per sample and layer, then averaged over the
    batch. The result is summed over layers and divided by the number of pixels
    of the input image.

    Parameters
    ----------
    topdown_data : dict[str, torch.Tensor]
        A dictionary containing information computed for each layer during the
        top-down pass. `topdown_data[kl_key]` must be an iterable with one KL
        tensor of shape (B,) per layer.
    img_shape : tuple[int]
        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
    kl_key : str
        The key for the KL-loss values in the top-down layer data dictionary.
        To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
        Default is "kl"

    Returns
    -------
    torch.Tensor
        The normalized KL loss as a single value.
    """
    stacked_kl = torch.cat(
        [layer_kl.unsqueeze(1) for layer_kl in topdown_data[kl_key]], dim=1
    )  # shape: (B, n_layers)

    # Compared to the uSplit KL divergence, this loss is larger by a factor of
    # `n_layers`, because contributions from different layers are summed
    # instead of averaged.
    summed_kl = free_bits_kl(stacked_kl, 1.0).sum()

    # NOTE: at each hierarchy, the KL loss is larger by a factor of (128/i**2):
    # 32 for the bottommost layer (i=2), then 8 (i=4), 2 (i=8) and 0.5 for the
    # topmost layer (i=16).

    # Normalize w.r.t. the input image spatial dimensions (e.g., 64x64).
    return summed_kl / np.prod(img_shape)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# TODO: @melisande-c suggested to refactor this as a class (see PR #208)
|
|
236
|
+
# - loss computation happens by calling the `__call__` method
|
|
237
|
+
# - `__init__` method initializes the loss parameters now contained in
|
|
238
|
+
# the `LVAELossParameters` class
|
|
239
|
+
# NOTE: same for the other loss functions
|
|
240
|
+
def musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function for muSplit.

    The loss is the sum of a Gaussian-likelihood reconstruction loss (scaled by
    `reconstruction_weight`) and a possibly annealed KL divergence loss.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters for muSplit (e.g., KL hyperparameters, likelihood module,
        noise model, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` if the overall loss is NaN.
    """
    predictions, td_data = model_outputs

    # Reconstruction term; a NaN value is replaced by 0.
    weighted_recons = get_reconstruction_loss(
        reconstruction=predictions,
        target=targets,
        likelihood_obj=loss_parameters.gaussian_likelihood,
    )["loss"] * loss_parameters.reconstruction_weight
    recons_loss = 0.0 if torch.isnan(weighted_recons).any() else weighted_recons

    # KL term, possibly annealed over epochs.
    weight = get_kl_weight(
        loss_parameters.kl_annealing,
        loss_parameters.kl_start,
        loss_parameters.kl_annealtime,
        loss_parameters.kl_weight,
        loss_parameters.current_epoch,
    )
    kl_loss = weight * get_kl_divergence_loss_usplit(td_data)

    net_loss = recons_loss + kl_loss
    # Return None on a NaN loss, following
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    if isinstance(recons_loss, torch.Tensor):
        logged_recons = recons_loss.detach()
    else:
        logged_recons = recons_loss
    return {
        "loss": net_loss,
        "reconstruction_loss": logged_recons,
        "kl_loss": kl_loss.detach(),
    }
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def denoisplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function for denoiSplit.

    The loss is the sum of a noise-model reconstruction loss (scaled by
    `reconstruction_weight`) and a possibly annealed KL divergence loss.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters for denoiSplit (e.g., KL hyperparameters, noise model
        likelihood, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` if the overall loss is NaN.
    """
    predictions, td_data = model_outputs

    # Reconstruction loss computation
    recons_loss_dict = get_reconstruction_loss(
        reconstruction=predictions,
        target=targets,
        likelihood_obj=loss_parameters.noise_model_likelihood,
    )
    recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
    if torch.isnan(recons_loss).any():
        recons_loss = 0.0

    # KL loss computation
    if loss_parameters.non_stochastic:  # TODO always false ?
        # Create the zero KL on the device of the model outputs instead of
        # hard-coding `.cuda()`, which fails on CPU-only or MPS machines.
        kl_loss = torch.zeros(1, device=predictions.device)
    else:
        kl_weight = get_kl_weight(
            loss_parameters.kl_annealing,
            loss_parameters.kl_start,
            loss_parameters.kl_annealtime,
            loss_parameters.kl_weight,
            loss_parameters.current_epoch,
        )
        kl_loss = kl_weight * get_kl_divergence_loss_denoisplit(
            topdown_data=td_data,
            img_shape=targets.shape[2:],  # input img spatial dims
        )

    net_loss = recons_loss + kl_loss
    output = {
        "loss": net_loss,
        "reconstruction_loss": (
            recons_loss.detach()
            if isinstance(recons_loss, torch.Tensor)
            else recons_loss
        ),
        "kl_loss": kl_loss.detach(),
    }
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    return output
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def denoisplit_musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function for the combined denoiSplit-muSplit objective.

    The reconstruction and KL terms are each weighted combinations of the
    denoiSplit and muSplit variants, using `denoisplit_weight` and
    `musplit_weight`. The combined KL term is further scaled by `kl_weight`.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters (e.g., KL hyperparameters, Gaussian and noise model
        likelihoods, denoiSplit/muSplit weights, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` if the overall loss is NaN.
    """
    predictions, td_data = model_outputs

    # Reconstruction loss computation
    recons_loss = reconstruction_loss_musplit_denoisplit(
        predictions=predictions,
        targets=targets,
        nm_likelihood=loss_parameters.noise_model_likelihood,
        gaussian_likelihood=loss_parameters.gaussian_likelihood,
        nm_weight=loss_parameters.denoisplit_weight,
        gaussian_weight=loss_parameters.musplit_weight,
    )
    if torch.isnan(recons_loss).any():
        recons_loss = 0.0

    # KL loss computation
    if loss_parameters.non_stochastic:  # TODO always false ?
        # Create the zero KL on the device of the model outputs instead of
        # hard-coding `.cuda()`, which fails on CPU-only or MPS machines.
        kl_loss = torch.zeros(1, device=predictions.device)
    else:
        # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
        # The different naming comes from `top_down_pass()` method in the LadderVAE.
        denoisplit_kl = get_kl_divergence_loss_denoisplit(
            topdown_data=td_data,
            img_shape=targets.shape[2:],  # input img spatial dims
        )
        musplit_kl = get_kl_divergence_loss_usplit(td_data)
        kl_loss = (
            loss_parameters.denoisplit_weight * denoisplit_kl
            + loss_parameters.musplit_weight * musplit_kl
        )
        # TODO `kl_weight` is hardcoded (???)
        kl_loss = loss_parameters.kl_weight * kl_loss

    net_loss = recons_loss + kl_loss
    output = {
        "loss": net_loss,
        "reconstruction_loss": (
            recons_loss.detach()
            if isinstance(recons_loss, torch.Tensor)
            else recons_loss
        ),
        "kl_loss": kl_loss.detach(),
    }
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    return output
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .multich_dataset import MultiChDloader
|
|
2
|
+
from .lc_dataset import LCMultiChDloader
|
|
3
|
+
from .multifile_dataset import MultiFileDset
|
|
4
|
+
from .config import DatasetConfig
|
|
5
|
+
from .types import DataType, DataSplitType, TilingMode
|
|
6
|
+
|
|
7
|
+
# Public API re-exported by this subpackage (order kept stable).
__all__ = [
    # configuration
    "DatasetConfig",
    # dataset classes
    "MultiChDloader",
    "LCMultiChDloader",
    "MultiFileDset",
    # data / split / tiling types
    "DataType",
    "DataSplitType",
    "TilingMode",
]
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict
|
|
4
|
+
|
|
5
|
+
from .types import DataType, DataSplitType, TilingMode
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# TODO: check if any bool logic can be removed
|
|
9
|
+
class DatasetConfig(BaseModel):
    """Configuration of the LVAE training datasets.

    Unknown fields are rejected and assignments made after construction are
    validated (see `model_config`).
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    # --- Data description ---------------------------------------------------
    data_type: Optional[DataType]
    """Type of the dataset, one of `DataType`. Required (no default)."""

    depth3D: Optional[int] = 1
    """Number of slices in 3D; equal to 1 for 2D data."""

    datasplit_type: Optional[DataSplitType] = None
    """Which split to return (training, validation or test), one of
    `DataSplitType`."""

    num_channels: Optional[int] = 2
    """Number of channels in the input."""

    # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
    ch1_fname: Optional[str] = None
    ch2_fname: Optional[str] = None
    ch_input_fname: Optional[str] = None

    input_is_sum: Optional[bool] = False
    """Whether the input is the sum or the average of the channels."""

    input_idx: Optional[int] = None
    """Index of the channel in which the input is stored in the data."""

    target_idx_list: Optional[list[int]] = None
    """Indices of the channels in which the targets are stored in the data."""

    # TODO: where are there used?
    start_alpha: Optional[Any] = None
    end_alpha: Optional[Any] = None

    # --- Patching -------------------------------------------------------------
    image_size: int
    """Size of one patch of data. Required (no default)."""

    grid_size: Optional[int] = None
    """The frame is divided into square grids of this size, and a patch of size
    `image_size` centered on a grid cell is returned. Not used in training; during
    validation / test it controls the overlap between patches."""

    # --- Empty-patch replacement and channel decorrelation --------------------
    empty_patch_replacement_enabled: Optional[bool] = False
    """Whether to replace, with a given probability, the content of one of the
    channels with background."""
    empty_patch_replacement_channel_idx: Optional[Any] = None
    empty_patch_replacement_probab: Optional[Any] = None
    empty_patch_max_val_threshold: Optional[Any] = None

    uncorrelated_channels: Optional[bool] = False
    """Whether to replace, with a given probability, the content of one of the
    channels so that channel contents become 'uncorrelated'."""
    uncorrelated_channel_probab: Optional[float] = 0.5

    # --- Noise ----------------------------------------------------------------
    poisson_noise_factor: Optional[float] = -1
    """Factor of the added Poisson noise."""

    synthetic_gaussian_scale: Optional[float] = 0.1

    # TODO: set to True in training code, recheck
    input_has_dependant_noise: Optional[bool] = False

    # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
    enable_gaussian_noise: Optional[bool] = False
    """Whether to enable Gaussian noise."""

    # --- Miscellaneous --------------------------------------------------------
    # TODO: is this parameter used?
    allow_generation: bool = False

    # TODO: both used in IndexSwitcher, insure correct passing
    training_validtarget_fraction: Any = None
    deterministic_grid: Any = None

    # TODO: why is this not used?
    enable_rotation_aug: Optional[bool] = False

    max_val: Optional[float] = None
    """Maximum value of the data. Computed for the train split; must be set
    externally for the validation and test splits."""

    overlapping_padding_kwargs: Any = None
    """Keyword arguments for `np.pad`."""

    # TODO: remove this parameter, controls debug print
    print_vars: Optional[bool] = False

    # --- Hard-coded parameters (used to be in the config file) ---------------
    # NOTE(review): meaning of `normalized_input` is not documented here.
    normalized_input: bool = True

    use_one_mu_std: Optional[bool] = True
    """If True, a single mean and standard deviation are used for both channels;
    otherwise, each channel gets its own mean and standard deviation."""

    # TODO: is this parameter used?
    train_aug_rotate: Optional[bool] = False
    enable_random_cropping: Optional[bool] = True

    multiscale_lowres_count: Optional[int] = None
    """Number of lateral-context (LC) scales."""

    tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary

    target_separate_normalization: Optional[bool] = True

    mode_3D: Optional[bool] = False
    """Whether training is done in 3D mode."""

    # NOTE: the field name is misspelled ("trainig") but is kept as-is because it
    # is part of the public configuration schema.
    trainig_datausage_fraction: Optional[float] = 1.0

    validtarget_random_fraction: Optional[float] = None

    validation_datausage_fraction: Optional[float] = 1.0

    random_flip_z_3D: Optional[bool] = False

    padding_kwargs: Optional[dict] = None
|