careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/lightning_module.py (new file)
@@ -0,0 +1,701 @@
+ """
+ Lightning Module for LadderVAE.
+ """
+
+ from typing import Any, Dict
+
+ import ml_collections
+ import numpy as np
+ import pytorch_lightning as L
+ import torch
+ import torchvision.transforms.functional as F
+
+ from careamics.models.lvae.likelihoods import LikelihoodModule
+ from careamics.models.lvae.lvae import LadderVAE
+ from careamics.models.lvae.utils import (
+     LossType,
+     compute_batch_mean,
+     free_bits_kl,
+     torch_nanmean,
+ )
+
+ from .metrics import RangeInvariantPsnr, RunningPSNR
+ from .train_utils import MetricMonitor
+
+
+ class LadderVAELight(L.LightningModule):
+
+     def __init__(
+         self,
+         config: ml_collections.ConfigDict,
+         data_mean: Dict[str, torch.Tensor],
+         data_std: Dict[str, torch.Tensor],
+         target_ch: int,
+     ):
+         """
+         Here we will do the following:
+         - initialize the model (from the LadderVAE class)
+         - initialize the parameters related to the training and loss.
+
+         NOTE:
+         Some of the model attributes are defined in the model object itself, while some others will be defined here.
+         Note that all the attributes related to the training and loss that were already defined in the model object
+         are redefined here as Lightning module attributes (e.g., self.some_attr = model.some_attr).
+         The attributes related to the model itself are treated as model attributes (e.g., self.model.some_attr).
+
+         NOTE: HC stands for Hard Coded attribute.
+         """
+         super().__init__()
+
+         self.data_mean = data_mean
+         self.data_std = data_std
+         self.target_ch = target_ch
+
+         # Initialize LVAE model
+         self.model = LadderVAE(
+             data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch
+         )
+
+         ##### Define attributes from config #####
+         self.workdir = config.workdir
+         self._input_is_sum = False
+         self.kl_loss_formulation = config.loss.kl_loss_formulation
+         assert self.kl_loss_formulation in [
+             None,
+             "",
+             "usplit",
+             "denoisplit",
+             "denoisplit_usplit",
+         ], f"""
+         Invalid kl_loss_formulation. {self.kl_loss_formulation}"""
+
+         ##### Define loss attributes #####
+         # Parameters already defined in the model object
+         self.loss_type = self.model.loss_type
+         self._denoisplit_w = self._usplit_w = None
+         if self.loss_type == LossType.DenoiSplitMuSplit:
+             self._usplit_w = 0
+             self._denoisplit_w = 1 - self._usplit_w
+             assert self._denoisplit_w + self._usplit_w == 1
+             self._restricted_kl = self.model._restricted_kl
+
+         # General loss parameters
+         self.channel_1_w = 1
+         self.channel_2_w = 1
+
+         # About Reconstruction Loss
+         self.reconstruction_mode = False
+         self.skip_nboundary_pixels_from_loss = None
+         self.reconstruction_weight = 1.0
+         self._exclusion_loss_weight = 0
+         self.ch1_recons_w = 1
+         self.ch2_recons_w = 1
+         self.enable_mixed_rec = False
+         self.mixed_rec_w_step = 0
+
+         # About KL Loss
+         self.kl_weight = 1.0  # HC
+         self.usplit_kl_weight = None  # HC
+         self.free_bits = 1.0  # HC
+         self.kl_annealing = False  # HC
+         self.kl_annealtime = self.kl_start = None
+         if self.kl_annealing:
+             self.kl_annealtime = 10  # HC
+             self.kl_start = -1  # HC
+
+         ##### Define training attributes #####
+         self.lr = config.training.lr
+         self.lr_scheduler_patience = config.training.lr_scheduler_patience
+         self.lr_scheduler_monitor = config.model.get("monitor", "val_loss")
+         self.lr_scheduler_mode = MetricMonitor(self.lr_scheduler_monitor).mode()
+
+         # Initialize object for keeping track of PSNR for each output channel
+         self.channels_psnr = [RunningPSNR() for _ in range(self.model.target_ch)]
+
+     def forward(self, x: Any) -> Any:
+         return self.model(x)
+
+     def training_step(
+         self, batch: torch.Tensor, batch_idx: int, enable_logging: bool = True
+     ) -> Dict[str, torch.Tensor]:
+
+         if self.current_epoch == 0 and batch_idx == 0:
+             self.log("val_psnr", 1.0, on_epoch=True)
+
+         # Pre-processing of inputs
+         x, target = batch[:2]
+         self.set_params_to_same_device_as(x)
+         x_normalized = self.normalize_input(x)
+         if self.reconstruction_mode:  # just for experimental purposes
+             target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1)
+             target = None
+             mask = None
+         else:
+             target_normalized = self.normalize_target(target)
+             mask = ~((target == 0).reshape(len(target), -1).all(dim=1))
+
+         # Forward pass
+         out, td_data = self.forward(x_normalized)
+
+         if (
+             self.model.encoder_no_padding_mode
+             and out.shape[-2:] != target_normalized.shape[-2:]
+         ):
+             target_normalized = F.center_crop(target_normalized, out.shape[-2:])
+
+         # Loss Computations
+         # mask = torch.isnan(target.reshape(len(x), -1)).all(dim=1)
+         recons_loss_dict, imgs = self.get_reconstruction_loss(
+             reconstruction=out,
+             target=target_normalized,
+             input=x_normalized,
+             splitting_mask=mask,
+             return_predicted_img=True,
+         )
+
+         # This `if` is not used by the default config
+         if self.skip_nboundary_pixels_from_loss:
+             pad = self.skip_nboundary_pixels_from_loss
+             target_normalized = target_normalized[:, :, pad:-pad, pad:-pad]
+
+         recons_loss = recons_loss_dict["loss"] * self.reconstruction_weight
+
+         if torch.isnan(recons_loss).any():
+             recons_loss = 0.0
+
+         if self.model.non_stochastic_version:
+             kl_loss = torch.Tensor([0.0]).cuda()
+             net_loss = recons_loss
+         else:
+             if self.loss_type == LossType.DenoiSplitMuSplit:
+                 msg = f"For the loss type {LossType.name(self.loss_type)}, kl_loss_formulation must be denoisplit_usplit"
+                 assert self.kl_loss_formulation == "denoisplit_usplit", msg
+                 assert self._denoisplit_w is not None and self._usplit_w is not None
+
+                 kl_key_denoisplit = "kl_restricted" if self._restricted_kl else "kl"
+                 # NOTE: the 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
+                 # The different naming comes from the `top_down_pass()` method in the LadderVAE class.
+                 denoisplit_kl = self.get_kl_divergence_loss(
+                     topdown_layer_data_dict=td_data, kl_key=kl_key_denoisplit
+                 )
+                 usplit_kl = self.get_kl_divergence_loss_usplit(
+                     topdown_layer_data_dict=td_data
+                 )
+                 kl_loss = (
+                     self._denoisplit_w * denoisplit_kl + self._usplit_w * usplit_kl
+                 )
+                 kl_loss = self.kl_weight * kl_loss
+
+                 recons_loss = self.reconstruction_loss_musplit_denoisplit(
+                     out, target_normalized
+                 )
+                 # recons_loss = self._denoisplit_w * recons_loss_nm + self._usplit_w * recons_loss_gm
+
+             elif self.kl_loss_formulation == "usplit":
+                 kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss_usplit(
+                     td_data
+                 )
+             elif self.kl_loss_formulation in ["", "denoisplit"]:
+                 kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss(td_data)
+             net_loss = recons_loss + kl_loss
+
+         # Logging
+         if enable_logging:
+             for i, x in enumerate(td_data["debug_qvar_max"]):
+                 self.log(f"qvar_max:{i}", x.item(), on_epoch=True)
+
+             self.log("reconstruction_loss", recons_loss_dict["loss"], on_epoch=True)
+             self.log("kl_loss", kl_loss, on_epoch=True)
+             self.log("training_loss", net_loss, on_epoch=True)
+             self.log("lr", self.lr, on_epoch=True)
+             if self.model._tethered_ch2_scalar is not None:
+                 self.log(
+                     "tethered_ch2_scalar",
+                     self.model._tethered_ch2_scalar,
+                     on_epoch=True,
+                 )
+                 self.log(
+                     "tethered_ch1_scalar",
+                     self.model._tethered_ch1_scalar,
+                     on_epoch=True,
+                 )
+
+         # self.log('grad_norm_bottom_up', self.grad_norm_bottom_up, on_epoch=True)
+         # self.log('grad_norm_top_down', self.grad_norm_top_down, on_epoch=True)
+
+         output = {
+             "loss": net_loss,
+             "reconstruction_loss": (
+                 recons_loss.detach()
+                 if isinstance(recons_loss, torch.Tensor)
+                 else recons_loss
+             ),
+             "kl_loss": kl_loss.detach(),
+         }
+         # https://github.com/openai/vdvae/blob/main/train.py#L26
+         if torch.isnan(net_loss).any():
+             return None
+
+         return output
+
+     def validation_step(self, batch: torch.Tensor, batch_idx: int):
+         # Pre-processing of inputs
+         x, target = batch[:2]
+         self.set_params_to_same_device_as(x)
+         x_normalized = self.normalize_input(x)
+         if self.reconstruction_mode:  # only for experimental purposes
+             target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1)
+             target = None
+             mask = None
+         else:
+             target_normalized = self.normalize_target(target)
+             mask = ~((target == 0).reshape(len(target), -1).all(dim=1))
+
+         # Forward pass
+         out, _ = self.forward(x_normalized)
+
+         if self.model.predict_logvar is not None:
+             out_mean, _ = out.chunk(2, dim=1)
+         else:
+             out_mean = out
+
+         if (
+             self.model.encoder_no_padding_mode
+             and out.shape[-2:] != target_normalized.shape[-2:]
+         ):
+             target_normalized = F.center_crop(target_normalized, out.shape[-2:])
+
+         if self.loss_type == LossType.DenoiSplitMuSplit:
+             recons_loss = self.reconstruction_loss_musplit_denoisplit(
+                 out, target_normalized
+             )
+             recons_loss_dict = {"loss": recons_loss}
+             recons_img = out_mean
+         else:
+             # Metrics computation
+             recons_loss_dict, recons_img = self.get_reconstruction_loss(
+                 reconstruction=out_mean,
+                 target=target_normalized,
+                 input=x_normalized,
+                 splitting_mask=mask,
+                 return_predicted_img=True,
+             )
+
+         # This `if` is not used by the default config
+         if self.skip_nboundary_pixels_from_loss:
+             pad = self.skip_nboundary_pixels_from_loss
+             target_normalized = target_normalized[:, :, pad:-pad, pad:-pad]
+
+         channels_rinvpsnr = []
+         for i in range(target_normalized.shape[1]):
+             self.channels_psnr[i].update(recons_img[:, i], target_normalized[:, i])
+             psnr = RangeInvariantPsnr(
+                 target_normalized[:, i].clone(), recons_img[:, i].clone()
+             )
+             channels_rinvpsnr.append(psnr)
+             psnr = torch_nanmean(psnr).item()
+             self.log(f"val_psnr_l{i+1}", psnr, on_epoch=True)
+
+         recons_loss = recons_loss_dict["loss"]
+         if torch.isnan(recons_loss).any():
+             return
+
+         self.log("val_loss", recons_loss, on_epoch=True)
+         # self.log('val_psnr', (val_psnr_l1 + val_psnr_l2) / 2, on_epoch=True)
+
+         # if batch_idx == 0 and self.power_of_2(self.current_epoch):
+         #     all_samples = []
+         #     for i in range(20):
+         #         sample, _ = self(x_normalized[0:1, ...])
+         #         sample = self.likelihood.get_mean_lv(sample)[0]
+         #         all_samples.append(sample[None])
+
+         #     all_samples = torch.cat(all_samples, dim=0)
+         #     all_samples = all_samples * self.data_std + self.data_mean
+         #     all_samples = all_samples.cpu()
+         #     img_mmse = torch.mean(all_samples, dim=0)[0]
+         #     self.log_images_for_tensorboard(all_samples[:, 0, 0, ...], target[0, 0, ...], img_mmse[0], 'label1')
+         #     self.log_images_for_tensorboard(all_samples[:, 0, 1, ...], target[0, 1, ...], img_mmse[1], 'label2')
+
+         # return net_loss
+
+     def on_validation_epoch_end(self):
+         psnr_arr = []
+         for i in range(len(self.channels_psnr)):
+             psnr = self.channels_psnr[i].get()
+             if psnr is None:
+                 psnr_arr = None
+                 break
+             psnr_arr.append(psnr.cpu().numpy())
+             self.channels_psnr[i].reset()
+
+         if psnr_arr is not None:
+             psnr = np.mean(psnr_arr)
+             self.log("val_psnr", psnr, on_epoch=True)
+         else:
+             self.log("val_psnr", 0.0, on_epoch=True)
+
+         if self.mixed_rec_w_step:
+             self.mixed_rec_w = max(self.mixed_rec_w - self.mixed_rec_w_step, 0.0)
+             self.log("mixed_rec_w", self.mixed_rec_w, on_epoch=True)
+
+     def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
+         raise NotImplementedError("predict_step is not implemented")
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adamax(self.parameters(), lr=self.lr, weight_decay=0)
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+             optimizer,
+             self.lr_scheduler_mode,
+             patience=self.lr_scheduler_patience,
+             factor=0.5,
+             min_lr=1e-12,
+             verbose=True,
+         )
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": self.lr_scheduler_monitor,
+         }
+
+     ##### REQUIRED Methods for Loss Computation #####
+     def get_reconstruction_loss(
+         self,
+         reconstruction: torch.Tensor,
+         target: torch.Tensor,
+         input: torch.Tensor,
+         splitting_mask: torch.Tensor = None,
+         return_predicted_img: bool = False,
+         likelihood_obj: LikelihoodModule = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Parameters
+         ----------
+         reconstruction: torch.Tensor
+         target: torch.Tensor
+         input: torch.Tensor
+         splitting_mask: torch.Tensor = None
+             A boolean tensor that indicates which items to keep for reconstruction loss computation.
+             If `None`, all the elements of the items are considered (i.e., the mask is all `True`).
+         return_predicted_img: bool = False
+         likelihood_obj: LikelihoodModule = None
+         """
+         output = self._get_reconstruction_loss_vector(
+             reconstruction=reconstruction,
+             target=target,
+             input=input,
+             return_predicted_img=return_predicted_img,
+             likelihood_obj=likelihood_obj,
+         )
+         loss_dict = output[0] if return_predicted_img else output
+
+         if splitting_mask is None:
+             splitting_mask = torch.ones_like(loss_dict["loss"]).bool()
+
+         # print(len(target) - (torch.isnan(loss_dict['loss'])).sum())
+
+         loss_dict["loss"] = loss_dict["loss"][splitting_mask].sum() / len(
+             reconstruction
+         )
+         for i in range(1, 1 + target.shape[1]):
+             key = f"ch{i}_loss"
+             loss_dict[key] = loss_dict[key][splitting_mask].sum() / len(reconstruction)
+
+         if "mixed_loss" in loss_dict:
+             loss_dict["mixed_loss"] = torch.mean(loss_dict["mixed_loss"])
+         if return_predicted_img:
+             assert len(output) == 2
+             return loss_dict, output[1]
+         else:
+             return loss_dict
+
+     def _get_reconstruction_loss_vector(
+         self,
+         reconstruction: torch.Tensor,
+         target: torch.Tensor,
+         input: torch.Tensor,
+         return_predicted_img: bool = False,
+         likelihood_obj: LikelihoodModule = None,
+     ):
+         """
+         Parameters
+         ----------
+         return_predicted_img: bool
+             If set to `True`, besides the loss, the reconstructed image is also returned.
+             Default is `False`.
+         """
+         output = {
+             "loss": None,
+             "mixed_loss": None,
+         }
+
+         for i in range(1, 1 + target.shape[1]):
+             output[f"ch{i}_loss"] = None
+
+         if likelihood_obj is None:
+             likelihood_obj = self.model.likelihood
+
+         # Log likelihood
+         ll, like_dict = likelihood_obj(reconstruction, target)
+         ll = self._get_weighted_likelihood(ll)
+         if (
+             self.skip_nboundary_pixels_from_loss is not None
+             and self.skip_nboundary_pixels_from_loss > 0
+         ):
+             pad = self.skip_nboundary_pixels_from_loss
+             ll = ll[:, :, pad:-pad, pad:-pad]
+             like_dict["params"]["mean"] = like_dict["params"]["mean"][
+                 :, :, pad:-pad, pad:-pad
+             ]
+
+         # assert ll.shape[1] == 2, f"Change the code below to handle >2 channels first. ll.shape {ll.shape}"
+         output = {"loss": compute_batch_mean(-1 * ll)}
+         if ll.shape[1] > 1:
+             for i in range(1, 1 + target.shape[1]):
+                 output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1])
+         else:
+             assert ll.shape[1] == 1
+             output["ch1_loss"] = output["loss"]
+             output["ch2_loss"] = output["loss"]
+
+         if (
+             self.channel_1_w is not None
+             and self.channel_2_w is not None
+             and (self.channel_1_w != 1 or self.channel_2_w != 1)
+         ):
+             assert ll.shape[1] == 2, "Only 2 channels are supported for now."
+             output["loss"] = (
+                 self.channel_1_w * output["ch1_loss"]
+                 + self.channel_2_w * output["ch2_loss"]
+             ) / (self.channel_1_w + self.channel_2_w)
+
+         # This `if` is not used by the default config
+         if self.enable_mixed_rec:
+             mixed_pred, mixed_logvar = self.get_mixed_prediction(
+                 like_dict["params"]["mean"],
+                 like_dict["params"]["logvar"],
+                 self.data_mean,
+                 self.data_std,
+             )
+             if (
+                 self.model._multiscale_count is not None
+                 and self.model._multiscale_count > 1
+             ):
+                 assert input.shape[1] == self.model._multiscale_count
+                 input = input[:, :1]
+
+             assert (
+                 input.shape == mixed_pred.shape
+             ), "No room for vectorization-induced bugs."
+             mixed_recons_ll = self.model.likelihood.log_likelihood(
+                 input, {"mean": mixed_pred, "logvar": mixed_logvar}
+             )
+             output["mixed_loss"] = compute_batch_mean(-1 * mixed_recons_ll)
+
+         # This `if` is not used by the default config
+         if self._exclusion_loss_weight:
+             raise NotImplementedError(
+                 "Exclusion loss is not well defined here, so it should not be used."
+             )
+             imgs = like_dict["params"]["mean"]
+             exclusion_loss = compute_exclusion_loss(imgs[:, :1], imgs[:, 1:])
+             output["exclusion_loss"] = exclusion_loss
+
+         if return_predicted_img:
+             return output, like_dict["params"]["mean"]
+
+         return output
+
+     def reconstruction_loss_musplit_denoisplit(self, out, target_normalized):
+         if self.model.predict_logvar is not None:
+             out_mean, _ = out.chunk(2, dim=1)
+         else:
+             out_mean = out
+
+         recons_loss_nm = (
+             -1 * self.model.likelihood_NM(out_mean, target_normalized)[0].mean()
+         )
+         recons_loss_gm = -1 * self.model.likelihood_gm(out, target_normalized)[0].mean()
+         recons_loss = (
+             self._denoisplit_w * recons_loss_nm + self._usplit_w * recons_loss_gm
+         )
+         return recons_loss
+
+     def _get_weighted_likelihood(self, ll):
+         """
+         Each of the channels gets multiplied with a different weight.
+         """
+         if self.ch1_recons_w == 1 and self.ch2_recons_w == 1:
+             return ll
+
+         assert ll.shape[1] == 2, "This function is only for 2-channel images"
+
+         mask1 = torch.zeros((len(ll), ll.shape[1], 1, 1), device=ll.device)
+         mask1[:, 0] = 1
+         mask2 = torch.zeros((len(ll), ll.shape[1], 1, 1), device=ll.device)
+         mask2[:, 1] = 1
+
+         return ll * mask1 * self.ch1_recons_w + ll * mask2 * self.ch2_recons_w
+
+     def get_kl_weight(self):
+         """
+         The KL loss can be weighted depending on whether an annealing procedure is used.
+         This function computes the weight of the KL loss in case of annealing.
+         """
+         if self.kl_annealing:
+             # calculate relative weight
+             kl_weight = (self.current_epoch - self.kl_start) * (
+                 1.0 / self.kl_annealtime
+             )
+             # clamp to [0,1]
+             kl_weight = min(max(0.0, kl_weight), 1.0)
+
+             # if the final weight is given, then apply that weight on top of it
+             if self.kl_weight is not None:
+                 kl_weight = kl_weight * self.kl_weight
+         elif self.kl_weight is not None:
+             return self.kl_weight
+         else:
+             kl_weight = 1.0
+         return kl_weight
+
+     def get_kl_divergence_loss_usplit(
+         self, topdown_layer_data_dict: Dict[str, torch.Tensor]
+     ) -> torch.Tensor:
+         """ """
+         kl = torch.cat(
+             [kl_layer.unsqueeze(1) for kl_layer in topdown_layer_data_dict["kl"]], dim=1
+         )
+         # NOTE: kl.shape = (16, 4), where 16 is the batch size and 4 is the number of layers.
+         # Values are sum() and so are of the order of 30000.
+         # Example values: 30626.6758, 31028.8145, 29509.8809, 29945.4922, 28919.1875, 29075.2988
+
+         nlayers = kl.shape[1]
+         for i in range(nlayers):
+             # topdown_layer_data_dict['z'][2].shape[-3:] = 128 * 32 * 32
+             norm_factor = np.prod(topdown_layer_data_dict["z"][i].shape[-3:])
+             # if self._restricted_kl:
+             #     pow = np.power(2, min(i + 1, self._multiscale_count - 1))
+             #     norm_factor /= pow * pow
+
+             kl[:, i] = kl[:, i] / norm_factor
+
+         kl_loss = free_bits_kl(kl, 0.0).mean()
+         return kl_loss
+
+     def get_kl_divergence_loss(self, topdown_layer_data_dict, kl_key="kl"):
+         """
+         kl[i] for each i has length batch_size;
+         resulting kl shape: (batch_size, layers).
+         """
+         kl = torch.cat(
+             [kl_layer.unsqueeze(1) for kl_layer in topdown_layer_data_dict[kl_key]],
+             dim=1,
+         )
+
+         # As compared to the uSplit KL divergence,
+         # this is larger by a factor of 4 just because we take the sum and not the mean.
+         kl_loss = free_bits_kl(kl, self.free_bits).sum()
+         # NOTE: at each hierarchy level, it is larger by a factor of 128/i**2:
+         # 128/(2*2) = 32 (bottommost layer)
+         # 128/(4*4) = 8
+         # 128/(8*8) = 2
+         # 128/(16*16) = 0.5 (topmost layer)
+
+         # Normalize the KL loss w.r.t. the latent space
+         kl_loss = kl_loss / np.prod(self.model.img_shape)
+         return kl_loss
+
+     ##### UTILS Methods #####
+     def normalize_input(self, x):
+         if self.model.normalized_input:
+             return x
+         return (x - self.data_mean["input"].mean()) / self.data_std["input"].mean()
+
+     def normalize_target(self, target, batch=None):
+         return (target - self.data_mean["target"]) / self.data_std["target"]
+
+     def unnormalize_target(self, target_normalized):
+         return target_normalized * self.data_std["target"] + self.data_mean["target"]
+
+     ##### ADDITIONAL Methods #####
+     # def log_images_for_tensorboard(self, pred, target, img_mmse, label):
+     #     clamped_pred = torch.clamp((pred - pred.min()) / (pred.max() - pred.min()), 0, 1)
+     #     clamped_mmse = torch.clamp((img_mmse - img_mmse.min()) / (img_mmse.max() - img_mmse.min()), 0, 1)
+     #     if target is not None:
+     #         clamped_input = torch.clamp((target - target.min()) / (target.max() - target.min()), 0, 1)
+     #         img = wandb.Image(clamped_input[None].cpu().numpy())
+     #         self.logger.experiment.log({f'target_for{label}': img})
+     #         # self.trainer.logger.experiment.add_image(f'target_for{label}', clamped_input[None], self.current_epoch)
+     #     for i in range(3):
+     #         # self.trainer.logger.experiment.add_image(f'{label}/sample_{i}', clamped_pred[i:i + 1], self.current_epoch)
+     #         img = wandb.Image(clamped_pred[i:i + 1].cpu().numpy())
+     #         self.logger.experiment.log({f'{label}/sample_{i}': img})
+
+     #     img = wandb.Image(clamped_mmse[None].cpu().numpy())
+     #     self.trainer.logger.experiment.log({f'{label}/mmse (100 samples)': img})
+
+     @property
+     def global_step(self) -> int:
+         """Global step."""
+         return self._global_step
+
+     def increment_global_step(self):
+         """Increments global step by 1."""
+         self._global_step += 1
+
+     def set_params_to_same_device_as(self, correct_device_tensor: torch.Tensor):
+
+         self.model.likelihood.set_params_to_same_device_as(correct_device_tensor)
+         if isinstance(self.data_mean, torch.Tensor):
+             if self.data_mean.device != correct_device_tensor.device:
+                 self.data_mean = self.data_mean.to(correct_device_tensor.device)
+                 self.data_std = self.data_std.to(correct_device_tensor.device)
+         elif isinstance(self.data_mean, dict):
+             for k, v in self.data_mean.items():
+                 if v.device != correct_device_tensor.device:
+                     self.data_mean[k] = v.to(correct_device_tensor.device)
+                     self.data_std[k] = self.data_std[k].to(correct_device_tensor.device)
+
+     def get_mixed_prediction(
+         self, prediction, prediction_logvar, data_mean, data_std, channel_weights=None
+     ):
+         pred_unorm = prediction * data_std["target"] + data_mean["target"]
+         if channel_weights is None:
+             channel_weights = 1
+
+         if self._input_is_sum:
+             mixed_prediction = torch.sum(
+                 pred_unorm * channel_weights, dim=1, keepdim=True
+             )
+         else:
+             mixed_prediction = torch.mean(
+                 pred_unorm * channel_weights, dim=1, keepdim=True
+             )
+
+         mixed_prediction = (mixed_prediction - data_mean["input"].mean()) / data_std[
+             "input"
+         ].mean()
+
+         if prediction_logvar is not None:
+             if data_std["target"].shape == data_std["input"].shape and torch.all(
+                 data_std["target"] == data_std["input"]
+             ):
+                 assert channel_weights == 1
+                 logvar = prediction_logvar
+             else:
+                 var = torch.exp(prediction_logvar)
+                 var = var * (data_std["target"] / data_std["input"]) ** 2
+                 if channel_weights != 1:
+                     var = var * torch.square(channel_weights)
+
+                 # sum of variance.
+                 mixed_var = 0
+                 for i in range(var.shape[1]):
+                     mixed_var += var[:, i : i + 1]
+
+                 logvar = torch.log(mixed_var)
+         else:
+             logvar = None
+         return mixed_prediction, logvar
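
For orientation, below is a minimal, hypothetical sketch (not part of the release) of how the new LadderVAELight module could be set up. It is inferred solely from what this file reads: the config fields config.workdir, config.loss.kl_loss_formulation, config.training.lr, config.training.lr_scheduler_patience, and config.model.monitor, plus data_mean/data_std dictionaries keyed "input"/"target". All values are placeholders, and the LadderVAE constructor consumes additional architecture and likelihood fields that are not visible in this diff.

import ml_collections
import torch

# Only the fields read by LadderVAELight itself are shown here; LadderVAE
# requires further (undocumented-in-this-diff) fields before construction works.
config = ml_collections.ConfigDict()
config.workdir = "./lvae_runs"                  # hypothetical output directory
config.loss = ml_collections.ConfigDict()
config.loss.kl_loss_formulation = "denoisplit"  # one of: None, "", "usplit", "denoisplit", "denoisplit_usplit"
config.training = ml_collections.ConfigDict()
config.training.lr = 1e-3                       # Adamax learning rate (placeholder)
config.training.lr_scheduler_patience = 15      # ReduceLROnPlateau patience (placeholder)
config.model = ml_collections.ConfigDict()
config.model.monitor = "val_loss"               # metric watched by the LR scheduler

# Normalization statistics, keyed exactly as normalize_input/normalize_target read them
# (here for a 1-channel input split into 2 target channels).
data_mean = {"input": torch.tensor([0.0]), "target": torch.tensor([0.0, 0.0])}
data_std = {"input": torch.tensor([1.0]), "target": torch.tensor([1.0, 1.0])}

# Once the remaining LadderVAE config fields are filled in, the module trains
# like any LightningModule:
# from careamics.lvae_training.lightning_module import LadderVAELight
# module = LadderVAELight(config=config, data_mean=data_mean,
#                         data_std=data_std, target_ch=2)
# pytorch_lightning.Trainer(max_epochs=100).fit(module, train_dl, val_dl)

Note that configure_optimizers() returns a ReduceLROnPlateau scheduler keyed to config.model.monitor (default "val_loss"), so a validation loop that logs that metric is required for the learning rate schedule to take effect.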