careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (87) hide show
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,83 @@
1
+ import torch
2
+
3
+
4
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """Return the batch-averaged KL per layer, with an optional free-bits floor.

    The free-bits floor keeps the KL of every layer from collapsing to zero, which
    encourages the model to actually use its latent variables.

    The input has shape (batch size, layers) and the output has shape (layers,).
    The two modes only differ in the order of the operations:

    - `batch_average=False` (default): each batch element is clamped to
      `free_bits` first, then the batch mean is taken, i.e. mean(clamp(KL)).
    - `batch_average=True`: the batch mean is taken first and the result is
      clamped to `free_bits`, i.e. clamp(mean(KL)).

    Parameters
    ----------
    kl : torch.Tensor
        The KL divergence tensor with shape (batch size, layers).
    free_bits : float
        The free-bits floor. Any value smaller than `eps` (e.g. 0.0) disables
        free bits and the plain batch mean is returned.
    batch_average : bool
        Whether to average over the batch before clamping to `free_bits`.
    eps : float
        Threshold below which `free_bits` is considered disabled.

    Returns
    -------
    torch.Tensor
        The (free-bits) KL divergence averaged over the batch, shape (layers,).
    """
    assert kl.dim() == 2

    # Free bits switched off: nothing to clamp, just average over the batch.
    if free_bits < eps:
        return kl.mean(dim=0)

    if batch_average:
        # clamp(mean(KL)): the floor applies to the batch as a whole.
        batch_mean = kl.mean(dim=0)
        return batch_mean.clamp(min=free_bits)

    # mean(clamp(KL)): the floor applies to every batch element separately.
    floored = kl.clamp(min=free_bits)
    return floored.mean(dim=0)
45
+
46
+
47
def get_kl_weight(
    kl_annealing: bool,
    kl_start: int,
    kl_annealtime: int,
    kl_weight: float,
    current_epoch: int,
) -> float:
    """Compute the weight of the KL loss, optionally applying annealing.

    Without annealing, the weight is `kl_weight`, or 1.0 if `kl_weight` is `None`.

    With annealing, a factor is ramped linearly from 0 to 1 between epoch
    `kl_start` and epoch `kl_start + kl_annealtime` and clamped to [0, 1]. If
    `kl_weight` is given, the factor is multiplied by it, so the weight ramps from
    0 up to `kl_weight`; otherwise the factor itself is returned.

    Parameters
    ----------
    kl_annealing : bool
        Whether to use KL annealing.
    kl_start : int
        The epoch at which annealing starts.
    kl_annealtime : int
        The number of epochs over which annealing is applied. Must be non-zero
        when `kl_annealing` is `True`.
    kl_weight : float
        The (final) weight for the KL loss. May be `None`, in which case a
        final weight of 1 is used.
    current_epoch : int
        The current epoch.

    Returns
    -------
    float
        The weight to apply to the KL loss at `current_epoch`.
    """
    if not kl_annealing:
        return 1.0 if kl_weight is None else kl_weight

    # Linear ramp, kept in its own variable so that the user-provided final
    # weight is not overwritten before it gets applied.
    annealing_factor = (current_epoch - kl_start) * (1.0 / kl_annealtime)
    # clamp to [0, 1]
    annealing_factor = min(max(0.0, annealing_factor), 1.0)

    # if the final weight is given, then apply that weight on top of the ramp
    if kl_weight is None:
        return annealing_factor
    return annealing_factor * kl_weight
@@ -0,0 +1,445 @@
1
+ """Methods for Loss Computation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight
11
+ from careamics.models.lvae.likelihoods import (
12
+ GaussianLikelihood,
13
+ LikelihoodModule,
14
+ NoiseModelLikelihood,
15
+ )
16
+ from careamics.models.lvae.utils import compute_batch_mean
17
+
18
+ if TYPE_CHECKING:
19
+ from careamics.losses.loss_factory import LVAELossParameters
20
+
21
+ Likelihood = Union[LikelihoodModule, GaussianLikelihood, NoiseModelLikelihood]
22
+
23
+
24
def get_reconstruction_loss(
    reconstruction: torch.Tensor,  # TODO: naming -> predictions?
    target: torch.Tensor,
    likelihood_obj: Likelihood,
) -> dict[str, torch.Tensor]:
    """Compute the reconstruction loss, averaged over the batch.

    The per-sample losses are obtained from `_get_reconstruction_loss_vector` and
    every entry of the resulting dictionary is reduced to a scalar by summing over
    the batch and dividing by the batch size.

    Parameters
    ----------
    reconstruction: torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the
        number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    target: torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, C, [Z], Y, X), where C is the number of output channels
        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    likelihood_obj: Likelihood
        The likelihood object used to compute the reconstruction loss.

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the overall loss `["loss"]` and the loss for
        individual output channels `["ch{i}_loss"]`, each reduced over the batch.
    """
    loss_dict = _get_reconstruction_loss_vector(
        reconstruction=reconstruction,
        target=target,
        likelihood_obj=likelihood_obj,
    )

    batch_size = len(reconstruction)
    # Reduce every entry that the helper produced, not only the first
    # `target.shape[1]` channels: in the single-channel case the helper also
    # emits a "ch2_loss" entry, which would otherwise be returned un-reduced
    # as a (B,) vector while all the other entries are scalars.
    return {key: value.sum() / batch_size for key, value in loss_dict.items()}
61
+
62
+
63
def _get_reconstruction_loss_vector(
    reconstruction: torch.Tensor,  # TODO: naming -> predictions?
    target: torch.Tensor,
    likelihood_obj: LikelihoodModule,
) -> dict[str, torch.Tensor]:
    """Compute the per-sample reconstruction loss (negative log-likelihood).

    Parameters
    ----------
    reconstruction: torch.Tensor
        The output of the LVAE decoder, passed as-is to `likelihood_obj`.
    target: torch.Tensor
        The target image, passed as-is to `likelihood_obj`. Its second dimension
        gives the number of per-channel entries in the multi-channel case.
    likelihood_obj: LikelihoodModule
        Callable returning a tuple whose first element is the log-likelihood.

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the overall loss `["loss"]` and the loss for
        individual output channels `["ch{i}_loss"]`. Shape of individual
        tensors is (B, ).
    """
    # Compute Log likelihood
    ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)

    output = {"loss": compute_batch_mean(-1 * ll)}  # shape: (B, )
    if ll.shape[1] > 1:  # target_ch > 1
        for i in range(1, 1 + target.shape[1]):
            output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1])  # shape: (B, )
    else:  # target_ch == 1
        # TODO: hacky!!! Refactor this
        # Both channel entries alias the overall loss, so that consumers which
        # expect "ch1_loss" and "ch2_loss" keep working with a single channel.
        assert ll.shape[1] == 1
        output["ch1_loss"] = output["loss"]
        output["ch2_loss"] = output["loss"]

    return output
101
+
102
+
103
def reconstruction_loss_musplit_denoisplit(
    predictions: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    targets: torch.Tensor,
    nm_likelihood: NoiseModelLikelihood,
    gaussian_likelihood: GaussianLikelihood,
    nm_weight: float,
    gaussian_weight: float,
) -> torch.Tensor:
    """Combine noise-model and Gaussian reconstruction losses (muSplit-denoiSplit).

    Each term is the negated mean of the first element returned by the
    corresponding likelihood; the result is
    `nm_weight * nm_term + gaussian_weight * gaussian_term`.

    Parameters
    ----------
    predictions : torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), or
        (B, 2*C, [Z], Y, X), where C is the number of output channels,
        and the factor of 2 is for the case of predicted log-variance.
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, C, [Z], Y, X), where C is the number of output channels
        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    nm_likelihood : NoiseModelLikelihood
        A `NoiseModelLikelihood` object used to compute the noise model likelihood.
        It only receives the predicted mean.
    gaussian_likelihood : GaussianLikelihood
        A `GaussianLikelihood` object used to compute the Gaussian likelihood.
        It receives the full `predictions` tensor.
    nm_weight : float
        The weight for the noise model likelihood.
    gaussian_weight : float
        The weight for the Gaussian likelihood.

    Returns
    -------
    recons_loss : torch.Tensor
        The weighted reconstruction loss (a scalar tensor).
    """
    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
    # (or viceversa)
    # Twice as many channels as the target means mean and log-variance are stacked
    # along the channel axis; the first half is the mean.
    has_logvar = predictions.shape[1] == 2 * targets.shape[1]
    out_mean = predictions.chunk(2, dim=1)[0] if has_logvar else predictions

    nm_log_likelihood = nm_likelihood(out_mean, targets)[0]
    gaussian_log_likelihood = gaussian_likelihood(predictions, targets)[0]

    nm_term = -1 * nm_log_likelihood.mean()
    gaussian_term = -1 * gaussian_log_likelihood.mean()
    return nm_weight * nm_term + gaussian_weight * gaussian_term
152
+
153
+
154
def get_kl_divergence_loss_usplit(
    topdown_data: dict[str, list[torch.Tensor]], kl_key: str = "kl"
) -> torch.Tensor:
    """Compute the KL divergence loss for muSplit.

    The KL of each layer is divided by the number of entries of that layer's
    latent tensor (all dimensions but the first), then averaged over the batch
    and over the layers.

    Parameters
    ----------
    topdown_data : dict[str, list[torch.Tensor]]
        A dictionary containing information computed for each layer during the
        top-down pass. The dictionary must include the following keys:
        - `kl_key` (default "kl"): The KL-loss values for each layer. Shape of
            each tensor is (B,).
        - "z": The sampled latents for each layer. The first dimension is assumed
            to be the batch dimension.
    kl_key : str
        The key for the KL-loss values in the top-down layer data dictionary.
        To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
        Default is "kl".

    Returns
    -------
    torch.Tensor
        The normalized KL loss (a scalar tensor).
    """
    layer_columns = [layer_kl.unsqueeze(1) for layer_kl in topdown_data[kl_key]]
    kl = torch.cat(layer_columns, dim=1)  # shape: (B, n_layers)
    # NOTE: Values are sum() and so are of the order 30000

    latents = topdown_data["z"]
    for layer_idx in range(kl.shape[1]):
        # NOTE: we want to normalize the KL-loss w.r.t. the latent space dimensions,
        # i.e., the number of entries in the latent space tensors (C, [Z], Y, X).
        # We assume z has shape (B, C, [Z], Y, X), where `C = z_dims[i]`.
        n_latent_entries = np.prod(latents[layer_idx].shape[1:])
        kl[:, layer_idx] = kl[:, layer_idx] / n_latent_entries

    # NOTE: free_bits disabled!
    per_layer_kl = free_bits_kl(kl, 0.0)
    return per_layer_kl.mean()
188
+
189
+
190
def get_kl_divergence_loss_denoisplit(
    topdown_data: dict[str, torch.Tensor],
    img_shape: tuple[int],
    kl_key: str = "kl",
) -> torch.Tensor:
    """Compute the KL divergence loss for denoiSplit.

    The per-layer KL values are stacked into a (B, layers) tensor, passed through
    `free_bits_kl` with a free-bits value of 1.0, summed over the layers and
    divided by the number of entries in `img_shape`.

    Parameters
    ----------
    topdown_data : dict[str, torch.Tensor]
        A dictionary containing information computed for each layer during the
        top-down pass. Only the entry under `kl_key` is read here: the KL-loss
        values for each layer, each of shape (B,).
    img_shape : tuple[int]
        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
    kl_key : str
        The key for the KL-loss values in the top-down layer data dictionary.
        To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
        Default is "kl"

    Returns
    -------
    torch.Tensor
        The normalized KL loss (a scalar tensor).
    """
    # kl[i] for each i has length batch_size, resulting kl shape: (bs, layers).
    layer_columns = [layer_kl.unsqueeze(1) for layer_kl in topdown_data[kl_key]]
    kl = torch.cat(layer_columns, dim=1)

    # NOTE: as compared to uSplit kl divergence, this KL loss is larger by a factor of
    # `n_layers` since we sum KL contributions from different layers instead of taking
    # the mean.
    summed_kl = free_bits_kl(kl, 1.0).sum()

    # NOTE: at each hierarchy, the KL loss is larger by a factor of (128/i**2).
    # 128/(2*2) = 32 (bottommost layer)
    # 128/(4*4) = 8
    # 128/(8*8) = 2
    # 128/(16*16) = 0.5 (topmost layer)

    # Normalize the KL-loss w.r.t. the input image spatial dimensions (e.g., 64x64)
    n_pixels = np.prod(img_shape)
    return summed_kl / n_pixels
233
+
234
+
235
+ # TODO: @melisande-c suggested to refactor this as a class (see PR #208)
236
+ # - loss computation happens by calling the `__call__` method
237
+ # - `__init__` method initializes the loss parameters now contained in
238
+ # the `LVAELossParameters` class
239
+ # NOTE: same for the other loss functions
240
def musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function for muSplit.

    The total loss is the weighted Gaussian reconstruction loss plus the weighted
    (possibly annealed) KL divergence loss.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters for muSplit (e.g., KL hyperparameters, likelihood module,
        noise model, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`, or
        `None` if the overall loss contains NaNs.
    """
    predictions, td_data = model_outputs

    # --- Reconstruction term ---
    recons_terms = get_reconstruction_loss(
        reconstruction=predictions,
        target=targets,
        likelihood_obj=loss_parameters.gaussian_likelihood,
    )
    recons_loss = recons_terms["loss"] * loss_parameters.reconstruction_weight
    if torch.isnan(recons_loss).any():
        # A NaN reconstruction term is dropped (replaced by a plain float zero).
        recons_loss = 0.0

    # --- KL term ---
    kl_weight = get_kl_weight(
        kl_annealing=loss_parameters.kl_annealing,
        kl_start=loss_parameters.kl_start,
        kl_annealtime=loss_parameters.kl_annealtime,
        kl_weight=loss_parameters.kl_weight,
        current_epoch=loss_parameters.current_epoch,
    )
    kl_loss = kl_weight * get_kl_divergence_loss_usplit(td_data)

    net_loss = recons_loss + kl_loss

    # `recons_loss` is a float when it was reset above, hence the type check.
    if isinstance(recons_loss, torch.Tensor):
        logged_recons_loss = recons_loss.detach()
    else:
        logged_recons_loss = recons_loss
    output = {
        "loss": net_loss,
        "reconstruction_loss": logged_recons_loss,
        "kl_loss": kl_loss.detach(),
    }

    # https://github.com/openai/vdvae/blob/main/train.py#L26
    return None if torch.isnan(net_loss).any() else output
302
+
303
+
304
def denoisplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function for DenoiSplit.

    The total loss is the weighted noise-model reconstruction loss plus the
    weighted (possibly annealed) KL divergence loss. The KL term is zero when
    `loss_parameters.non_stochastic` is set.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters for denoiSplit (e.g., KL hyperparameters, likelihood
        module, noise model, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`, or
        `None` if the overall loss contains NaNs.
    """
    predictions, td_data = model_outputs

    # Reconstruction loss computation
    recons_loss_dict = get_reconstruction_loss(
        reconstruction=predictions,
        target=targets,
        likelihood_obj=loss_parameters.noise_model_likelihood,
    )
    recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
    if torch.isnan(recons_loss).any():
        # A NaN reconstruction term is dropped (replaced by a plain float zero).
        recons_loss = 0.0

    # KL loss computation
    if loss_parameters.non_stochastic:  # TODO always false ?
        # Create the zero on the device of the data rather than hard-coding CUDA,
        # so that this branch also works on CPU-only (or non-CUDA) runs.
        kl_loss = torch.tensor([0.0], device=targets.device)
    else:
        kl_weight = get_kl_weight(
            loss_parameters.kl_annealing,
            loss_parameters.kl_start,
            loss_parameters.kl_annealtime,
            loss_parameters.kl_weight,
            loss_parameters.current_epoch,
        )
        kl_loss = kl_weight * get_kl_divergence_loss_denoisplit(
            topdown_data=td_data,
            img_shape=targets.shape[2:],  # input img spatial dims
        )

    net_loss = recons_loss + kl_loss
    output = {
        "loss": net_loss,
        # `recons_loss` is a float when it was reset above, hence the type check.
        "reconstruction_loss": (
            recons_loss.detach()
            if isinstance(recons_loss, torch.Tensor)
            else recons_loss
        ),
        "kl_loss": kl_loss.detach(),
    }
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    return output
372
+
373
+
374
def denoisplit_musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    loss_parameters: LVAELossParameters,
) -> Optional[dict[str, torch.Tensor]]:
    """Loss function combining the denoiSplit and muSplit losses.

    The reconstruction term is a weighted sum of the noise-model and Gaussian
    reconstruction losses. The KL term is a weighted sum of the denoiSplit and
    muSplit KL losses, scaled by `loss_parameters.kl_weight`. The weights are
    `loss_parameters.denoisplit_weight` and `loss_parameters.musplit_weight`.
    The KL term is zero when `loss_parameters.non_stochastic` is set.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, `target_ch`, [Z], Y, X).
    loss_parameters : LVAELossParameters
        The loss parameters (e.g., KL hyperparameters, likelihood modules,
        denoiSplit/muSplit weights, etc.).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`, or
        `None` if the overall loss contains NaNs.
    """
    predictions, td_data = model_outputs

    # Reconstruction loss computation
    recons_loss = reconstruction_loss_musplit_denoisplit(
        predictions=predictions,
        targets=targets,
        nm_likelihood=loss_parameters.noise_model_likelihood,
        gaussian_likelihood=loss_parameters.gaussian_likelihood,
        nm_weight=loss_parameters.denoisplit_weight,
        gaussian_weight=loss_parameters.musplit_weight,
    )
    if torch.isnan(recons_loss).any():
        # A NaN reconstruction term is dropped (replaced by a plain float zero).
        recons_loss = 0.0

    # KL loss computation
    if loss_parameters.non_stochastic:  # TODO always false ?
        # Create the zero on the device of the data rather than hard-coding CUDA,
        # so that this branch also works on CPU-only (or non-CUDA) runs.
        kl_loss = torch.tensor([0.0], device=targets.device)
    else:
        # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
        # The different naming comes from `top_down_pass()` method in the LadderVAE.
        denoisplit_kl = get_kl_divergence_loss_denoisplit(
            topdown_data=td_data,
            img_shape=targets.shape[2:],  # input img spatial dims
        )
        musplit_kl = get_kl_divergence_loss_usplit(td_data)
        kl_loss = (
            loss_parameters.denoisplit_weight * denoisplit_kl
            + loss_parameters.musplit_weight * musplit_kl
        )
        # TODO `kl_weight` is hardcoded (???)
        # NOTE(review): unlike the other losses, no annealing is applied here and
        # a `kl_weight` of `None` would raise a TypeError - confirm intended.
        kl_loss = loss_parameters.kl_weight * kl_loss

    net_loss = recons_loss + kl_loss
    output = {
        "loss": net_loss,
        # `recons_loss` is a float when it was reset above, hence the type check.
        "reconstruction_loss": (
            recons_loss.detach()
            if isinstance(recons_loss, torch.Tensor)
            else recons_loss
        ),
        "kl_loss": kl_loss.detach(),
    }
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    return output
@@ -0,0 +1,15 @@
1
+ from .multich_dataset import MultiChDloader
2
+ from .lc_dataset import LCMultiChDloader
3
+ from .multifile_dataset import MultiFileDset
4
+ from .config import DatasetConfig
5
+ from .types import DataType, DataSplitType, TilingMode
6
+
7
+ __all__ = [
8
+ "DatasetConfig",
9
+ "MultiChDloader",
10
+ "LCMultiChDloader",
11
+ "MultiFileDset",
12
+ "DataType",
13
+ "DataSplitType",
14
+ "TilingMode",
15
+ ]
@@ -0,0 +1,123 @@
1
+ from typing import Any, Optional
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+ from .types import DataType, DataSplitType, TilingMode
6
+
7
+
8
+ # TODO: check if any bool logic can be removed
9
class DatasetConfig(BaseModel):
    """Pydantic configuration model for the LVAE training datasets.

    Assignments are validated (`validate_assignment=True`) and unknown fields are
    rejected (`extra="forbid"`). `data_type` and `image_size` have no default
    value; every other field has one.
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    data_type: Optional[DataType]
    """Type of the dataset, should be one of DataType"""

    depth3D: Optional[int] = 1
    """Number of slices in 3D. If data is 2D depth3D is equal to 1"""

    datasplit_type: Optional[DataSplitType] = None
    """Whether to return training, validation or test split, should be one of
    DataSplitType"""

    num_channels: Optional[int] = 2
    """Number of channels in the input"""

    # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
    ch1_fname: Optional[str] = None
    ch2_fname: Optional[str] = None
    ch_input_fname: Optional[str] = None

    input_is_sum: Optional[bool] = False
    """Whether the input is the sum or average of channels"""

    input_idx: Optional[int] = None
    """Index of the channel where the input is stored in the data"""

    target_idx_list: Optional[list[int]] = None
    """Indices of the channels where the targets are stored in the data"""

    # TODO: where are these used?
    start_alpha: Optional[Any] = None
    end_alpha: Optional[Any] = None

    image_size: int
    """Size of one patch of data"""

    grid_size: Optional[int] = None
    """Frame is divided into square grids of this size. A patch centered on a grid
    having size `image_size` is returned. Grid size not used in training,
    used only during val / test, grid size controls the overlap of the patches"""

    empty_patch_replacement_enabled: Optional[bool] = False
    """Whether to replace the content of one of the channels
    with background with given probability"""
    # Companion settings of `empty_patch_replacement_enabled` (untyped for now).
    empty_patch_replacement_channel_idx: Optional[Any] = None
    empty_patch_replacement_probab: Optional[Any] = None
    empty_patch_max_val_threshold: Optional[Any] = None

    uncorrelated_channels: Optional[bool] = False
    """Replace the content in one of the channels with given probability to make
    channel content 'uncorrelated'"""
    # Presumably the probability used when `uncorrelated_channels` is enabled
    # - TODO confirm against the dataset code.
    uncorrelated_channel_probab: Optional[float] = 0.5

    poisson_noise_factor: Optional[float] = -1
    """The added poisson noise factor"""

    # Presumably the scale of the synthetic Gaussian noise (see
    # `enable_gaussian_noise`) - TODO confirm against the dataset code.
    synthetic_gaussian_scale: Optional[float] = 0.1

    # TODO: set to True in training code, recheck
    input_has_dependant_noise: Optional[bool] = False

    # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
    enable_gaussian_noise: Optional[bool] = False
    """Whether to enable gaussian noise"""

    # TODO: is this parameter used?
    allow_generation: bool = False

    # TODO: both used in IndexSwitcher, ensure correct passing
    training_validtarget_fraction: Any = None
    deterministic_grid: Any = None

    # TODO: why is this not used?
    enable_rotation_aug: Optional[bool] = False

    max_val: Optional[float] = None
    """Maximum data in the dataset. Is calculated for train split, and should be
    externally set for val and test splits."""

    overlapping_padding_kwargs: Any = None
    """Parameters for np.pad method"""

    # TODO: remove this parameter, controls debug print
    print_vars: Optional[bool] = False

    # Hard-coded parameters (used to be in the config file)
    normalized_input: bool = True
    # NOTE(review): the string below sits after `normalized_input`, but its text
    # ("one mean and stdev is used for both channels") reads like a description
    # of `use_one_mu_std` - confirm which field it documents.
    """If this is set to true, then one mean and stdev is used
    for both channels. Otherwise, two different mean and stdev are used."""
    use_one_mu_std: Optional[bool] = True

    # TODO: is this parameter used?
    train_aug_rotate: Optional[bool] = False
    enable_random_cropping: Optional[bool] = True

    multiscale_lowres_count: Optional[int] = None
    """Number of LC scales"""

    tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary

    target_separate_normalization: Optional[bool] = True

    mode_3D: Optional[bool] = False
    """If training in 3D mode or not"""

    # NOTE(review): the field name is misspelled ("trainig"); renaming it would
    # break existing configs because `extra="forbid"` rejects unknown keys.
    trainig_datausage_fraction: Optional[float] = 1.0

    validtarget_random_fraction: Optional[float] = None

    validation_datausage_fraction: Optional[float] = 1.0

    random_flip_z_3D: Optional[bool] = False

    padding_kwargs: Optional[dict] = None