careamics 0.0.4.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the original registry listing for more details.

Files changed (43)
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +35 -22
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Optional, Union
5
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
6
6
 
7
7
  import numpy as np
8
8
  import torch
@@ -13,20 +13,19 @@ from careamics.models.lvae.likelihoods import (
13
13
  LikelihoodModule,
14
14
  NoiseModelLikelihood,
15
15
  )
16
- from careamics.models.lvae.utils import compute_batch_mean
17
16
 
18
17
  if TYPE_CHECKING:
19
- from careamics.losses.loss_factory import LVAELossParameters
18
+ from careamics.config import LVAELossConfig
20
19
 
21
20
  Likelihood = Union[LikelihoodModule, GaussianLikelihood, NoiseModelLikelihood]
22
21
 
23
22
 
24
23
  def get_reconstruction_loss(
25
- reconstruction: torch.Tensor, # TODO: naming -> predictions?
24
+ reconstruction: torch.Tensor,
26
25
  target: torch.Tensor,
27
26
  likelihood_obj: Likelihood,
28
27
  ) -> dict[str, torch.Tensor]:
29
- """Compute the reconstruction loss.
28
+ """Compute the reconstruction loss (negative log-likelihood).
30
29
 
31
30
  Parameters
32
31
  ----------
@@ -42,65 +41,15 @@ def get_reconstruction_loss(
42
41
 
43
42
  Returns
44
43
  -------
45
- dict[str, torch.Tensor]
46
- A dictionary containing the overall loss `["loss"]` and the loss for
47
- individual output channels `["ch{i}_loss"]`.
44
+ torch.Tensor
45
+ The recontruction loss (negative log-likelihood).
48
46
  """
49
- loss_dict = _get_reconstruction_loss_vector(
50
- reconstruction=reconstruction,
51
- target=target,
52
- likelihood_obj=likelihood_obj,
53
- )
54
-
55
- loss_dict["loss"] = loss_dict["loss"].sum() / len(reconstruction)
56
- for i in range(1, 1 + target.shape[1]):
57
- key = f"ch{i}_loss"
58
- loss_dict[key] = loss_dict[key].sum() / len(reconstruction)
59
-
60
- return loss_dict
61
-
62
-
63
- def _get_reconstruction_loss_vector(
64
- reconstruction: torch.Tensor, # TODO: naming -> predictions?
65
- target: torch.Tensor,
66
- likelihood_obj: LikelihoodModule,
67
- ) -> dict[str, torch.Tensor]:
68
- """Compute the reconstruction loss.
69
-
70
- Parameters
71
- ----------
72
- return_predicted_img: bool
73
- If set to `True`, the besides the loss, the reconstructed image is returned.
74
- Default is `False`.
75
-
76
- Returns
77
- -------
78
- dict[str, torch.Tensor]
79
- A dictionary containing the overall loss `["loss"]` and the loss for
80
- individual output channels `["ch{i}_loss"]`. Shape of individual
81
- tensors is (B, ).
82
- """
83
- output = {"loss": None}
84
- for i in range(1, 1 + target.shape[1]):
85
- output[f"ch{i}_loss"] = None
86
-
87
47
  # Compute Log likelihood
88
48
  ll, _ = likelihood_obj(reconstruction, target) # shape: (B, C, [Z], Y, X)
89
-
90
- output = {"loss": compute_batch_mean(-1 * ll)} # shape: (B, )
91
- if ll.shape[1] > 1: # target_ch > 1
92
- for i in range(1, 1 + target.shape[1]):
93
- output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1]) # shape: (B, )
94
- else: # target_ch == 1
95
- # TODO: hacky!!! Refactor this
96
- assert ll.shape[1] == 1
97
- output["ch1_loss"] = output["loss"]
98
- output["ch2_loss"] = output["loss"]
99
-
100
- return output
49
+ return -1 * ll.mean()
101
50
 
102
51
 
103
- def reconstruction_loss_musplit_denoisplit(
52
+ def _reconstruction_loss_musplit_denoisplit(
104
53
  predictions: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
105
54
  targets: torch.Tensor,
106
55
  nm_likelihood: NoiseModelLikelihood,
@@ -137,62 +86,120 @@ def reconstruction_loss_musplit_denoisplit(
137
86
  recons_loss : torch.Tensor
138
87
  The reconstruction loss. Shape is (1, ).
139
88
  """
140
- # TODO: refactor this function to make it closer to `get_reconstruction_loss`
141
- # (or viceversa)
142
89
  if predictions.shape[1] == 2 * targets.shape[1]:
143
90
  # predictions contain both mean and log-variance
144
- out_mean, _ = predictions.chunk(2, dim=1)
91
+ pred_mean, _ = predictions.chunk(2, dim=1)
145
92
  else:
146
- out_mean = predictions
93
+ pred_mean = predictions
94
+
95
+ recons_loss_nm = get_reconstruction_loss(
96
+ reconstruction=pred_mean, target=targets, likelihood_obj=nm_likelihood
97
+ )
98
+
99
+ recons_loss_gm = get_reconstruction_loss(
100
+ reconstruction=predictions,
101
+ target=targets,
102
+ likelihood_obj=gaussian_likelihood,
103
+ )
147
104
 
148
- recons_loss_nm = -1 * nm_likelihood(out_mean, targets)[0].mean()
149
- recons_loss_gm = -1 * gaussian_likelihood(predictions, targets)[0].mean()
150
105
  recons_loss = nm_weight * recons_loss_nm + gaussian_weight * recons_loss_gm
151
106
  return recons_loss
152
107
 
153
108
 
154
- def get_kl_divergence_loss_usplit(
155
- topdown_data: dict[str, list[torch.Tensor]], kl_key: str = "kl"
109
+ def get_kl_divergence_loss(
110
+ kl_type: Literal["kl", "kl_restricted"],
111
+ topdown_data: dict[str, torch.Tensor],
112
+ rescaling: Literal["latent_dim", "image_dim"],
113
+ aggregation: Literal["mean", "sum"],
114
+ free_bits_coeff: float,
115
+ img_shape: Optional[tuple[int]] = None,
156
116
  ) -> torch.Tensor:
157
- """Compute the KL divergence loss for muSplit.
117
+ """Compute the KL divergence loss.
118
+
119
+ NOTE: Description of `rescaling` methods:
120
+ - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space
121
+ dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way they
122
+ have the same magnitude across layers.
123
+ - If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial
124
+ dimensions. In this way, the lower layers have a larger KL-loss value compared to
125
+ the higher layers, since the latent space and hence the KL tensor has more entries.
126
+ Specifically, at hierarchy `i`, the total KL loss is larger by a factor (128/i**2).
127
+
128
+ NOTE: the type of `aggregation` determines the magnitude of the KL-loss. Clearly,
129
+ "sum" aggregation results in a larger KL-loss value compared to "mean" by a factor
130
+ of `n_layers`.
131
+
132
+ NOTE: recall that sample-wise KL is obtained by summing over all dimensions,
133
+ including Z. Also recall that in current 3D implementation of LVAE, no downsampling
134
+ is done on Z. Therefore, to avoid emphasizing KL loss too much, we divide it
135
+ by the Z dimension of input image in every case.
158
136
 
159
137
  Parameters
160
138
  ----------
161
- topdown_data : dict[str, list[torch.Tensor]]
139
+ kl_type : Literal["kl", "kl_restricted"]
140
+ The type of KL divergence loss to compute.
141
+ topdown_data : dict[str, torch.Tensor]
162
142
  A dictionary containing information computed for each layer during the top-down
163
143
  pass. The dictionary must include the following keys:
164
144
  - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
165
145
  - "z": The sampled latents for each layer. Shape of each tensor is
166
146
  (B, layers, `z_dims[i]`, H, W).
167
- kl_key : str
168
- The key for the KL-loss values in the top-down layer data dictionary.
169
- To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
170
- Default is "kl".
147
+ rescaling : Literal["latent_dim", "image_dim"]
148
+ The rescaling method used for the KL-loss values. If "latent_dim", the KL-loss
149
+ values are rescaled w.r.t. the latent space dimensions (spatial + number of
150
+ channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss values are
151
+ rescaled w.r.t. the input image spatial dimensions.
152
+ aggregation : Literal["mean", "sum"]
153
+ The aggregation method used to combine the KL-loss values across layers. If
154
+ "mean", the KL-loss values are averaged across layers. If "sum", the KL-loss
155
+ values are summed across layers.
156
+ free_bits_coeff : float
157
+ The free bits coefficient used for the KL-loss computation.
158
+ img_shape : Optional[tuple[int]]
159
+ The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
160
+
161
+ Returns
162
+ -------
163
+ kl_loss : torch.Tensor
164
+ The KL divergence loss. Shape is (1, ).
171
165
  """
172
166
  kl = torch.cat(
173
- [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_key]], dim=1
167
+ [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
168
+ dim=1,
174
169
  ) # shape: (B, n_layers)
175
- # NOTE: Values are sum() and so are of the order 30000
176
170
 
177
- nlayers = kl.shape[1]
178
- for i in range(nlayers):
179
- # NOTE: we want to normalize the KL-loss w.r.t. the latent space dimensions,
180
- # i.e., the number of entries in the latent space tensors (C, [Z], Y, X).
181
- # We assume z has shape (B, C, [Z], Y, X), where `C = z_dims[i]`.
182
- norm_factor = np.prod(topdown_data["z"][i].shape[1:])
183
- kl[:, i] = kl[:, i] / norm_factor
171
+ # Apply free bits (& batch average)
172
+ kl = free_bits_kl(kl, free_bits_coeff) # shape: (n_layers,)
184
173
 
185
- kl_loss = free_bits_kl(kl, 0.0).mean() # shape: (1, )
186
- # NOTE: free_bits disabled!
187
- return kl_loss
174
+ # In 3D case, rescale by Z dim
175
+ # TODO If we have downsampling in Z dimension, then this needs to change.
176
+ if len(img_shape) == 3:
177
+ kl = kl / img_shape[0]
188
178
 
179
+ # Rescaling
180
+ if rescaling == "latent_dim":
181
+ for i in range(len(kl)):
182
+ latent_dim = topdown_data["z"][i].shape[1:]
183
+ norm_factor = np.prod(latent_dim)
184
+ kl[i] = kl[i] / norm_factor
185
+ elif rescaling == "image_dim":
186
+ kl = kl / np.prod(img_shape[-2:])
189
187
 
190
- def get_kl_divergence_loss_denoisplit(
188
+ # Aggregation
189
+ if aggregation == "mean":
190
+ kl = kl.mean() # shape: (1,)
191
+ elif aggregation == "sum":
192
+ kl = kl.sum() # shape: (1,)
193
+
194
+ return kl
195
+
196
+
197
+ def _get_kl_divergence_loss_musplit(
191
198
  topdown_data: dict[str, torch.Tensor],
192
199
  img_shape: tuple[int],
193
- kl_key: str = "kl",
200
+ kl_type: Literal["kl", "kl_restricted"],
194
201
  ) -> torch.Tensor:
195
- """Compute the KL divergence loss for denoiSplit.
202
+ """Compute the KL divergence loss for muSplit.
196
203
 
197
204
  Parameters
198
205
  ----------
@@ -204,32 +211,57 @@ def get_kl_divergence_loss_denoisplit(
204
211
  (B, layers, `z_dims[i]`, H, W).
205
212
  img_shape : tuple[int]
206
213
  The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
207
- kl_key : str
208
- The key for the KL-loss values in the top-down layer data dictionary.
209
- To choose among ["kl", "kl_restricted", "kl_spatial", "kl_channelwise"]
210
- Default is "kl"
214
+ kl_type : Literal["kl", "kl_restricted"]
215
+ The type of KL divergence loss to compute.
211
216
 
212
- kl[i] for each i has length batch_size resulting kl shape: (bs, layers).
217
+ Returns
218
+ -------
219
+ kl_loss : torch.Tensor
220
+ The KL divergence loss for the muSplit case. Shape is (1, ).
213
221
  """
214
- kl = torch.cat(
215
- [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_key]],
216
- dim=1,
222
+ return get_kl_divergence_loss(
223
+ kl_type="kl", # TODO: hardcoded, deal in future PR
224
+ topdown_data=topdown_data,
225
+ rescaling="latent_dim",
226
+ aggregation="mean",
227
+ free_bits_coeff=0.0,
228
+ img_shape=img_shape,
217
229
  )
218
230
 
219
- kl_loss = free_bits_kl(kl, 1.0).sum()
220
- # NOTE: as compared to uSplit kl divergence, this KL loss is larger by a factor of
221
- # `n_layers` since we sum KL contributions from different layers instead of taking
222
- # the mean.
223
231
 
224
- # NOTE: at each hierarchy, the KL loss is larger by a factor of (128/i**2).
225
- # 128/(2*2) = 32 (bottommost layer)
226
- # 128/(4*4) = 8
227
- # 128/(8*8) = 2
228
- # 128/(16*16) = 0.5 (topmost layer)
232
+ def _get_kl_divergence_loss_denoisplit(
233
+ topdown_data: dict[str, torch.Tensor],
234
+ img_shape: tuple[int],
235
+ kl_type: Literal["kl", "kl_restricted"],
236
+ ) -> torch.Tensor:
237
+ """Compute the KL divergence loss for denoiSplit.
229
238
 
230
- # Normalize the KL-loss w.r.t. the input image spatial dimensions (e.g., 64x64)
231
- kl_loss = kl_loss / np.prod(img_shape)
232
- return kl_loss
239
+ Parameters
240
+ ----------
241
+ topdown_data : dict[str, torch.Tensor]
242
+ A dictionary containing information computed for each layer during the top-down
243
+ pass. The dictionary must include the following keys:
244
+ - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
245
+ - "z": The sampled latents for each layer. Shape of each tensor is
246
+ (B, layers, `z_dims[i]`, H, W).
247
+ img_shape : tuple[int]
248
+ The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
249
+ kl_type : Literal["kl", "kl_restricted"]
250
+ The type of KL divergence loss to compute.
251
+
252
+ Returns
253
+ -------
254
+ kl_loss : torch.Tensor
255
+ The KL divergence loss for the denoiSplit case. Shape is (1, ).
256
+ """
257
+ return get_kl_divergence_loss(
258
+ kl_type=kl_type,
259
+ topdown_data=topdown_data,
260
+ rescaling="image_dim",
261
+ aggregation="sum",
262
+ free_bits_coeff=1.0,
263
+ img_shape=img_shape,
264
+ )
233
265
 
234
266
 
235
267
  # TODO: @melisande-c suggested to refactor this as a class (see PR #208)
@@ -240,7 +272,9 @@ def get_kl_divergence_loss_denoisplit(
240
272
  def musplit_loss(
241
273
  model_outputs: tuple[torch.Tensor, dict[str, Any]],
242
274
  targets: torch.Tensor,
243
- loss_parameters: LVAELossParameters,
275
+ config: LVAELossConfig,
276
+ gaussian_likelihood: Optional[GaussianLikelihood],
277
+ noise_model_likelihood: Optional[NoiseModelLikelihood] = None, # TODO: ugly
244
278
  ) -> Optional[dict[str, torch.Tensor]]:
245
279
  """Loss function for muSplit.
246
280
 
@@ -252,9 +286,13 @@ def musplit_loss(
252
286
  targets : torch.Tensor
253
287
  The target image used to compute the reconstruction loss. Shape is
254
288
  (B, `target_ch`, [Z], Y, X).
255
- loss_parameters : LVAELossParameters
256
- The loss parameters for muSplit (e.g., KL hyperparameters, likelihood module,
289
+ config : LVAELossConfig
290
+ The config for loss function (e.g., KL hyperparameters, likelihood module,
257
291
  noise model, etc.).
292
+ gaussian_likelihood : GaussianLikelihood
293
+ The Gaussian likelihood object.
294
+ noise_model_likelihood : Optional[NoiseModelLikelihood]
295
+ The noise model likelihood object. Not used here.
258
296
 
259
297
  Returns
260
298
  -------
@@ -262,27 +300,35 @@ def musplit_loss(
262
300
  A dictionary containing the overall loss `["loss"]`, the reconstruction loss
263
301
  `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
264
302
  """
303
+ assert gaussian_likelihood is not None
304
+
265
305
  predictions, td_data = model_outputs
266
306
 
267
307
  # Reconstruction loss computation
268
- recons_loss_dict = get_reconstruction_loss(
308
+ recons_loss = config.reconstruction_weight * get_reconstruction_loss(
269
309
  reconstruction=predictions,
270
310
  target=targets,
271
- likelihood_obj=loss_parameters.gaussian_likelihood,
311
+ likelihood_obj=gaussian_likelihood,
272
312
  )
273
- recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
274
313
  if torch.isnan(recons_loss).any():
275
314
  recons_loss = 0.0
276
315
 
277
316
  # KL loss computation
278
317
  kl_weight = get_kl_weight(
279
- loss_parameters.kl_annealing,
280
- loss_parameters.kl_start,
281
- loss_parameters.kl_annealtime,
282
- loss_parameters.kl_weight,
283
- loss_parameters.current_epoch,
318
+ config.kl_params.annealing,
319
+ config.kl_params.start,
320
+ config.kl_params.annealtime,
321
+ config.kl_weight,
322
+ config.kl_params.current_epoch,
323
+ )
324
+ kl_loss = (
325
+ _get_kl_divergence_loss_musplit(
326
+ topdown_data=td_data,
327
+ img_shape=targets.shape[2:],
328
+ kl_type=config.kl_params.loss_type,
329
+ )
330
+ * kl_weight
284
331
  )
285
- kl_loss = kl_weight * get_kl_divergence_loss_usplit(td_data)
286
332
 
287
333
  net_loss = recons_loss + kl_loss
288
334
  output = {
@@ -304,7 +350,9 @@ def musplit_loss(
304
350
  def denoisplit_loss(
305
351
  model_outputs: tuple[torch.Tensor, dict[str, Any]],
306
352
  targets: torch.Tensor,
307
- loss_parameters: LVAELossParameters,
353
+ config: LVAELossConfig,
354
+ gaussian_likelihood: Optional[GaussianLikelihood] = None,
355
+ noise_model_likelihood: Optional[NoiseModelLikelihood] = None,
308
356
  ) -> Optional[dict[str, torch.Tensor]]:
309
357
  """Loss function for DenoiSplit.
310
358
 
@@ -316,9 +364,12 @@ def denoisplit_loss(
316
364
  targets : torch.Tensor
317
365
  The target image used to compute the reconstruction loss. Shape is
318
366
  (B, `target_ch`, [Z], Y, X).
319
- loss_parameters : LVAELossParameters
320
- The loss parameters for muSplit (e.g., KL hyperparameters, likelihood module,
321
- noise model, etc.).
367
+ config : LVAELossConfig
368
+ The config for loss function containing all loss hyperparameters.
369
+ gaussian_likelihood : GaussianLikelihood
370
+ The Gaussian likelihood object.
371
+ noise_model_likelihood : NoiseModelLikelihood
372
+ The noise model likelihood object.
322
373
 
323
374
  Returns
324
375
  -------
@@ -326,33 +377,35 @@ def denoisplit_loss(
326
377
  A dictionary containing the overall loss `["loss"]`, the reconstruction loss
327
378
  `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
328
379
  """
380
+ assert noise_model_likelihood is not None
381
+
329
382
  predictions, td_data = model_outputs
330
383
 
331
384
  # Reconstruction loss computation
332
- recons_loss_dict = get_reconstruction_loss(
385
+ recons_loss = config.reconstruction_weight * get_reconstruction_loss(
333
386
  reconstruction=predictions,
334
387
  target=targets,
335
- likelihood_obj=loss_parameters.noise_model_likelihood,
388
+ likelihood_obj=noise_model_likelihood,
336
389
  )
337
- recons_loss = recons_loss_dict["loss"] * loss_parameters.reconstruction_weight
338
390
  if torch.isnan(recons_loss).any():
339
391
  recons_loss = 0.0
340
392
 
341
393
  # KL loss computation
342
- if loss_parameters.non_stochastic: # TODO always false ?
343
- kl_loss = torch.Tensor([0.0]).cuda()
344
- else:
345
- kl_weight = get_kl_weight(
346
- loss_parameters.kl_annealing,
347
- loss_parameters.kl_start,
348
- loss_parameters.kl_annealtime,
349
- loss_parameters.kl_weight,
350
- loss_parameters.current_epoch,
351
- )
352
- kl_loss = kl_weight * get_kl_divergence_loss_denoisplit(
394
+ kl_weight = get_kl_weight(
395
+ config.kl_params.annealing,
396
+ config.kl_params.start,
397
+ config.kl_params.annealtime,
398
+ config.kl_weight,
399
+ config.kl_params.current_epoch,
400
+ )
401
+ kl_loss = (
402
+ _get_kl_divergence_loss_denoisplit(
353
403
  topdown_data=td_data,
354
- img_shape=targets.shape[2:], # input img spatial dims
404
+ img_shape=targets.shape[2:],
405
+ kl_type=config.kl_params.loss_type,
355
406
  )
407
+ * kl_weight
408
+ )
356
409
 
357
410
  net_loss = recons_loss + kl_loss
358
411
  output = {
@@ -374,7 +427,9 @@ def denoisplit_loss(
374
427
  def denoisplit_musplit_loss(
375
428
  model_outputs: tuple[torch.Tensor, dict[str, Any]],
376
429
  targets: torch.Tensor,
377
- loss_parameters: LVAELossParameters,
430
+ config: LVAELossConfig,
431
+ gaussian_likelihood: GaussianLikelihood,
432
+ noise_model_likelihood: NoiseModelLikelihood,
378
433
  ) -> Optional[dict[str, torch.Tensor]]:
379
434
  """Loss function for DenoiSplit.
380
435
 
@@ -386,9 +441,12 @@ def denoisplit_musplit_loss(
386
441
  targets : torch.Tensor
387
442
  The target image used to compute the reconstruction loss. Shape is
388
443
  (B, `target_ch`, [Z], Y, X).
389
- loss_parameters : LVAELossParameters
390
- The loss parameters for muSplit (e.g., KL hyperparameters, likelihood module,
391
- noise model, etc.).
444
+ config : LVAELossConfig
445
+ The config for loss function containing all loss hyperparameters.
446
+ gaussian_likelihood : GaussianLikelihood
447
+ The Gaussian likelihood object.
448
+ noise_model_likelihood : NoiseModelLikelihood
449
+ The noise model likelihood object.
392
450
 
393
451
  Returns
394
452
  -------
@@ -399,34 +457,35 @@ def denoisplit_musplit_loss(
399
457
  predictions, td_data = model_outputs
400
458
 
401
459
  # Reconstruction loss computation
402
- recons_loss = reconstruction_loss_musplit_denoisplit(
460
+ recons_loss = _reconstruction_loss_musplit_denoisplit(
403
461
  predictions=predictions,
404
462
  targets=targets,
405
- nm_likelihood=loss_parameters.noise_model_likelihood,
406
- gaussian_likelihood=loss_parameters.gaussian_likelihood,
407
- nm_weight=loss_parameters.denoisplit_weight,
408
- gaussian_weight=loss_parameters.musplit_weight,
463
+ nm_likelihood=noise_model_likelihood,
464
+ gaussian_likelihood=gaussian_likelihood,
465
+ nm_weight=config.denoisplit_weight,
466
+ gaussian_weight=config.musplit_weight,
409
467
  )
410
468
  if torch.isnan(recons_loss).any():
411
469
  recons_loss = 0.0
412
470
 
413
471
  # KL loss computation
414
- if loss_parameters.non_stochastic: # TODO always false ?
415
- kl_loss = torch.Tensor([0.0]).cuda()
416
- else:
417
- # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
418
- # The different naming comes from `top_down_pass()` method in the LadderVAE.
419
- denoisplit_kl = get_kl_divergence_loss_denoisplit(
420
- topdown_data=td_data,
421
- img_shape=targets.shape[2:], # input img spatial dims
422
- )
423
- musplit_kl = get_kl_divergence_loss_usplit(td_data)
424
- kl_loss = (
425
- loss_parameters.denoisplit_weight * denoisplit_kl
426
- + loss_parameters.musplit_weight * musplit_kl
427
- )
428
- # TODO `kl_weight` is hardcoded (???)
429
- kl_loss = loss_parameters.kl_weight * kl_loss
472
+ # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
473
+ # The different naming comes from `top_down_pass()` method in the LadderVAE.
474
+ denoisplit_kl = _get_kl_divergence_loss_denoisplit(
475
+ topdown_data=td_data,
476
+ img_shape=targets.shape[2:],
477
+ kl_type=config.kl_params.loss_type,
478
+ )
479
+ musplit_kl = _get_kl_divergence_loss_musplit(
480
+ topdown_data=td_data,
481
+ img_shape=targets.shape[2:],
482
+ kl_type=config.kl_params.loss_type,
483
+ )
484
+ kl_loss = (
485
+ config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl
486
+ )
487
+ # TODO `kl_weight` is hardcoded (???)
488
+ kl_loss = config.kl_weight * kl_loss
430
489
 
431
490
  net_loss = recons_loss + kl_loss
432
491
  output = {