careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +6 -12
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +843 -29
- careamics/config/data/data_model.py +1 -2
- careamics/config/data/ng_data_model.py +1 -2
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/lightning_module.py +161 -61
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/xy_random_rotate90.py +1 -1
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
|
@@ -5,7 +5,7 @@ from typing import Any, Literal, Union
|
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
|
-
|
|
8
|
+
import torch
|
|
9
9
|
|
|
10
10
|
from careamics.config import (
|
|
11
11
|
N2VAlgorithm,
|
|
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
|
|
|
71
71
|
Learning rate scheduler name.
|
|
72
72
|
"""
|
|
73
73
|
|
|
74
|
-
def __init__(
|
|
74
|
+
def __init__(
|
|
75
|
+
self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
|
|
76
|
+
) -> None:
|
|
75
77
|
"""Lightning module for CAREamics.
|
|
76
78
|
|
|
77
79
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
|
|
|
98
100
|
self.n2v_preprocess = None
|
|
99
101
|
|
|
100
102
|
self.algorithm = algorithm_config.algorithm
|
|
101
|
-
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
103
|
+
self.model: torch.nn.Module = model_factory(algorithm_config.model)
|
|
102
104
|
self.loss_func = loss_factory(algorithm_config.loss)
|
|
103
105
|
|
|
104
106
|
# save optimizer and lr_scheduler names and parameters
|
|
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
|
|
|
122
124
|
"""
|
|
123
125
|
return self.model(x)
|
|
124
126
|
|
|
125
|
-
def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
127
|
+
def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
126
128
|
"""Training step.
|
|
127
129
|
|
|
128
130
|
Parameters
|
|
129
131
|
----------
|
|
130
|
-
batch : torch.Tensor
|
|
132
|
+
batch : torch.torch.Tensor
|
|
131
133
|
Input batch.
|
|
132
134
|
batch_idx : Any
|
|
133
135
|
Batch index.
|
|
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
|
|
|
154
156
|
self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
|
|
155
157
|
return loss
|
|
156
158
|
|
|
157
|
-
def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
|
|
159
|
+
def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
|
|
158
160
|
"""Validation step.
|
|
159
161
|
|
|
160
162
|
Parameters
|
|
161
163
|
----------
|
|
162
|
-
batch : torch.Tensor
|
|
164
|
+
batch : torch.torch.Tensor
|
|
163
165
|
Input batch.
|
|
164
166
|
batch_idx : Any
|
|
165
167
|
Batch index.
|
|
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
|
|
|
184
186
|
logger=True,
|
|
185
187
|
)
|
|
186
188
|
|
|
187
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
189
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
188
190
|
"""Prediction step.
|
|
189
191
|
|
|
190
192
|
Parameters
|
|
191
193
|
----------
|
|
192
|
-
batch : torch.Tensor
|
|
194
|
+
batch : torch.torch.torch.Tensor
|
|
193
195
|
Input batch.
|
|
194
196
|
batch_idx : Any
|
|
195
197
|
Batch index.
|
|
@@ -330,17 +332,21 @@ class VAEModule(L.LightningModule):
|
|
|
330
332
|
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
331
333
|
|
|
332
334
|
# create model
|
|
333
|
-
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
335
|
+
self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
|
|
334
336
|
|
|
337
|
+
# supervised_mode
|
|
338
|
+
self.supervised_mode = self.algorithm_config.is_supervised
|
|
335
339
|
# create loss function
|
|
336
340
|
self.noise_model: NoiseModel | None = noise_model_factory(
|
|
337
341
|
self.algorithm_config.noise_model
|
|
338
342
|
)
|
|
339
343
|
|
|
340
|
-
self.noise_model_likelihood: NoiseModelLikelihood | None =
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
+
self.noise_model_likelihood: NoiseModelLikelihood | None = None
|
|
345
|
+
if self.algorithm_config.noise_model_likelihood is not None:
|
|
346
|
+
self.noise_model_likelihood = likelihood_factory(
|
|
347
|
+
config=self.algorithm_config.noise_model_likelihood,
|
|
348
|
+
noise_model=self.noise_model,
|
|
349
|
+
)
|
|
344
350
|
|
|
345
351
|
self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
|
|
346
352
|
self.algorithm_config.gaussian_likelihood
|
|
@@ -360,30 +366,43 @@ class VAEModule(L.LightningModule):
|
|
|
360
366
|
RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
|
|
361
367
|
]
|
|
362
368
|
|
|
363
|
-
def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
|
|
369
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
|
|
364
370
|
"""Forward pass.
|
|
365
371
|
|
|
366
372
|
Parameters
|
|
367
373
|
----------
|
|
368
|
-
x : Tensor
|
|
374
|
+
x : torch.Tensor
|
|
369
375
|
Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
370
376
|
number of lateral inputs.
|
|
371
377
|
|
|
372
378
|
Returns
|
|
373
379
|
-------
|
|
374
|
-
tuple[Tensor, dict[str, Any]]
|
|
380
|
+
tuple[torch.Tensor, dict[str, Any]]
|
|
375
381
|
A tuple with the output tensor and additional data from the top-down pass.
|
|
376
382
|
"""
|
|
377
383
|
return self.model(x) # TODO Different model can have more than one output
|
|
378
384
|
|
|
385
|
+
def set_data_stats(self, data_mean, data_std):
|
|
386
|
+
"""Set data mean and std for the noise model likelihood.
|
|
387
|
+
|
|
388
|
+
Parameters
|
|
389
|
+
----------
|
|
390
|
+
data_mean : float
|
|
391
|
+
Mean of the data.
|
|
392
|
+
data_std : float
|
|
393
|
+
Standard deviation of the data.
|
|
394
|
+
"""
|
|
395
|
+
if self.noise_model_likelihood is not None:
|
|
396
|
+
self.noise_model_likelihood.set_data_stats(data_mean, data_std)
|
|
397
|
+
|
|
379
398
|
def training_step(
|
|
380
|
-
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
381
|
-
) -> dict[str, Tensor] | None:
|
|
399
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
400
|
+
) -> dict[str, torch.Tensor] | None:
|
|
382
401
|
"""Training step.
|
|
383
402
|
|
|
384
403
|
Parameters
|
|
385
404
|
----------
|
|
386
|
-
batch : tuple[Tensor, Tensor]
|
|
405
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
387
406
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
388
407
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
389
408
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -397,15 +416,29 @@ class VAEModule(L.LightningModule):
|
|
|
397
416
|
Any
|
|
398
417
|
Loss value.
|
|
399
418
|
"""
|
|
400
|
-
x, target = batch
|
|
419
|
+
x, *target = batch
|
|
401
420
|
|
|
402
421
|
# Forward pass
|
|
403
422
|
out = self.model(x)
|
|
423
|
+
if not self.supervised_mode:
|
|
424
|
+
target = x
|
|
425
|
+
else:
|
|
426
|
+
target = target[
|
|
427
|
+
0
|
|
428
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the dataset level
|
|
404
429
|
|
|
405
430
|
# Update loss parameters
|
|
406
431
|
self.loss_parameters.kl_params.current_epoch = self.current_epoch
|
|
407
432
|
|
|
408
433
|
# Compute loss
|
|
434
|
+
if self.noise_model_likelihood is not None:
|
|
435
|
+
if (
|
|
436
|
+
self.noise_model_likelihood.data_mean is None
|
|
437
|
+
or self.noise_model_likelihood.data_std is None
|
|
438
|
+
):
|
|
439
|
+
raise RuntimeError(
|
|
440
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before training."
|
|
441
|
+
)
|
|
409
442
|
loss = self.loss_func(
|
|
410
443
|
model_outputs=out,
|
|
411
444
|
targets=target,
|
|
@@ -417,15 +450,26 @@ class VAEModule(L.LightningModule):
|
|
|
417
450
|
# Logging
|
|
418
451
|
# TODO: implement a separate logging method?
|
|
419
452
|
self.log_dict(loss, on_step=True, on_epoch=True)
|
|
420
|
-
|
|
453
|
+
|
|
454
|
+
try:
|
|
455
|
+
optimizer = self.optimizers()
|
|
456
|
+
current_lr = optimizer.param_groups[0]["lr"]
|
|
457
|
+
self.log(
|
|
458
|
+
"learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
|
|
459
|
+
)
|
|
460
|
+
except RuntimeError:
|
|
461
|
+
# This happens when the module is not attached to a trainer, e.g., in tests
|
|
462
|
+
pass
|
|
421
463
|
return loss
|
|
422
464
|
|
|
423
|
-
def validation_step(
|
|
465
|
+
def validation_step(
|
|
466
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
467
|
+
) -> None:
|
|
424
468
|
"""Validation step.
|
|
425
469
|
|
|
426
470
|
Parameters
|
|
427
471
|
----------
|
|
428
|
-
batch : tuple[Tensor, Tensor]
|
|
472
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
429
473
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
430
474
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
431
475
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -434,11 +478,16 @@ class VAEModule(L.LightningModule):
|
|
|
434
478
|
batch_idx : Any
|
|
435
479
|
Batch index.
|
|
436
480
|
"""
|
|
437
|
-
x, target = batch
|
|
481
|
+
x, *target = batch
|
|
438
482
|
|
|
439
483
|
# Forward pass
|
|
440
484
|
out = self.model(x)
|
|
441
|
-
|
|
485
|
+
if not self.supervised_mode:
|
|
486
|
+
target = x
|
|
487
|
+
else:
|
|
488
|
+
target = target[
|
|
489
|
+
0
|
|
490
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the datasel level
|
|
442
491
|
# Compute loss
|
|
443
492
|
loss = self.loss_func(
|
|
444
493
|
model_outputs=out,
|
|
@@ -464,12 +513,12 @@ class VAEModule(L.LightningModule):
|
|
|
464
513
|
else:
|
|
465
514
|
self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
|
|
466
515
|
|
|
467
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
516
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
468
517
|
"""Prediction step.
|
|
469
518
|
|
|
470
519
|
Parameters
|
|
471
520
|
----------
|
|
472
|
-
batch : Tensor
|
|
521
|
+
batch : torch.Tensor
|
|
473
522
|
Input batch.
|
|
474
523
|
batch_idx : Any
|
|
475
524
|
Batch index.
|
|
@@ -479,36 +528,86 @@ class VAEModule(L.LightningModule):
|
|
|
479
528
|
Any
|
|
480
529
|
Model output.
|
|
481
530
|
"""
|
|
482
|
-
if self.
|
|
531
|
+
if self.algorithm_config.algorithm == "microsplit":
|
|
483
532
|
x, *aux = batch
|
|
484
|
-
|
|
485
|
-
x
|
|
486
|
-
aux = []
|
|
533
|
+
# Reset model for inference with spatial dimensions only (H, W)
|
|
534
|
+
self.model.reset_for_inference(x.shape[-2:])
|
|
487
535
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
493
|
-
augmented_output = []
|
|
494
|
-
for augmented in augmented_batch:
|
|
495
|
-
augmented_pred = self.model(augmented)
|
|
496
|
-
augmented_output.append(augmented_pred)
|
|
497
|
-
output = tta.backward(augmented_output)
|
|
498
|
-
else:
|
|
499
|
-
output = self.model(x)
|
|
536
|
+
rec_img_list = []
|
|
537
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
538
|
+
# get model output
|
|
539
|
+
rec, _ = self.model(x)
|
|
500
540
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
541
|
+
# get reconstructed img
|
|
542
|
+
if self.model.predict_logvar is None:
|
|
543
|
+
rec_img = rec
|
|
544
|
+
logvar = torch.tensor([-1])
|
|
545
|
+
else:
|
|
546
|
+
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
547
|
+
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
548
|
+
|
|
549
|
+
# aggregate results
|
|
550
|
+
samples = torch.cat(rec_img_list, dim=0)
|
|
551
|
+
mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
|
|
552
|
+
std_imgs = torch.std(samples, dim=0) # std over MMSE dim
|
|
553
|
+
|
|
554
|
+
tile_prediction = mmse_imgs.cpu().numpy()
|
|
555
|
+
tile_std = std_imgs.cpu().numpy()
|
|
556
|
+
|
|
557
|
+
return tile_prediction, tile_std
|
|
507
558
|
|
|
508
|
-
if len(aux) > 0: # aux can be tiling information
|
|
509
|
-
return denormalized_output, *aux
|
|
510
559
|
else:
|
|
511
|
-
|
|
560
|
+
# Regular prediction logic
|
|
561
|
+
if self._trainer.datamodule.tiled:
|
|
562
|
+
# TODO tile_size should match model input size
|
|
563
|
+
x, *aux = batch
|
|
564
|
+
x = (
|
|
565
|
+
x[0] if isinstance(x, list | tuple) else x
|
|
566
|
+
) # TODO ugly, so far i don't know why x might be a list
|
|
567
|
+
self.model.reset_for_inference(x.shape) # TODO should it be here ?
|
|
568
|
+
else:
|
|
569
|
+
x = batch[0] if isinstance(batch, list | tuple) else batch
|
|
570
|
+
aux = []
|
|
571
|
+
self.model.reset_for_inference(x.shape)
|
|
572
|
+
|
|
573
|
+
mmse_list = []
|
|
574
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
575
|
+
# apply test-time augmentation if available
|
|
576
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
577
|
+
tta = ImageRestorationTTA()
|
|
578
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
579
|
+
augmented_output = []
|
|
580
|
+
for augmented in augmented_batch:
|
|
581
|
+
augmented_pred = self.model(augmented)
|
|
582
|
+
augmented_output.append(augmented_pred)
|
|
583
|
+
output = tta.backward(augmented_output)
|
|
584
|
+
else:
|
|
585
|
+
output = self.model(x)
|
|
586
|
+
|
|
587
|
+
# taking the 1st element of the output, 2nd is std if
|
|
588
|
+
# predict_logvar=="pixelwise"
|
|
589
|
+
output = (
|
|
590
|
+
output[0]
|
|
591
|
+
if self.model.predict_logvar is None
|
|
592
|
+
else output[0][:, 0:1, ...]
|
|
593
|
+
)
|
|
594
|
+
mmse_list.append(output)
|
|
595
|
+
|
|
596
|
+
mmse = torch.stack(mmse_list).mean(0)
|
|
597
|
+
std = torch.stack(mmse_list).std(0) # TODO why?
|
|
598
|
+
# TODO better way to unpack if pred logvar
|
|
599
|
+
# Denormalize the output
|
|
600
|
+
denorm = Denormalize(
|
|
601
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
602
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
denormalized_output = denorm(patch=mmse.cpu().numpy())
|
|
606
|
+
|
|
607
|
+
if len(aux) > 0: # aux can be tiling information
|
|
608
|
+
return denormalized_output, std, *aux
|
|
609
|
+
else:
|
|
610
|
+
return denormalized_output, std
|
|
512
611
|
|
|
513
612
|
def configure_optimizers(self) -> Any:
|
|
514
613
|
"""Configure optimizers and learning rate schedulers.
|
|
@@ -537,19 +636,19 @@ class VAEModule(L.LightningModule):
|
|
|
537
636
|
# should we refactor LadderVAE so that it already outputs
|
|
538
637
|
# tuple(`mean`, `logvar`, `td_data`)?
|
|
539
638
|
def get_reconstructed_tensor(
|
|
540
|
-
self, model_outputs: tuple[Tensor, dict[str, Any]]
|
|
541
|
-
) -> Tensor:
|
|
639
|
+
self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
|
|
640
|
+
) -> torch.Tensor:
|
|
542
641
|
"""Get the reconstructed tensor from the LVAE model outputs.
|
|
543
642
|
|
|
544
643
|
Parameters
|
|
545
644
|
----------
|
|
546
|
-
model_outputs : tuple[Tensor, dict[str, Any]]
|
|
645
|
+
model_outputs : tuple[torch.Tensor, dict[str, Any]]
|
|
547
646
|
Model outputs. It is a tuple with a tensor representing the predicted mean
|
|
548
647
|
and (optionally) logvar, and the top-down data dictionary.
|
|
549
648
|
|
|
550
649
|
Returns
|
|
551
650
|
-------
|
|
552
|
-
Tensor
|
|
651
|
+
torch.Tensor
|
|
553
652
|
Reconstructed tensor, i.e., the predicted mean.
|
|
554
653
|
"""
|
|
555
654
|
predictions, _ = model_outputs
|
|
@@ -560,18 +659,18 @@ class VAEModule(L.LightningModule):
|
|
|
560
659
|
|
|
561
660
|
def compute_val_psnr(
|
|
562
661
|
self,
|
|
563
|
-
model_output: tuple[Tensor, dict[str, Any]],
|
|
564
|
-
target: Tensor,
|
|
662
|
+
model_output: tuple[torch.Tensor, dict[str, Any]],
|
|
663
|
+
target: torch.Tensor,
|
|
565
664
|
psnr_func: Callable = scale_invariant_psnr,
|
|
566
665
|
) -> list[float]:
|
|
567
666
|
"""Compute the PSNR for the current validation batch.
|
|
568
667
|
|
|
569
668
|
Parameters
|
|
570
669
|
----------
|
|
571
|
-
model_output : tuple[Tensor, dict[str, Any]]
|
|
670
|
+
model_output : tuple[torch.Tensor, dict[str, Any]]
|
|
572
671
|
Model output, a tuple with the predicted mean and (optionally) logvar,
|
|
573
672
|
and the top-down data dictionary.
|
|
574
|
-
target : Tensor
|
|
673
|
+
target : torch.Tensor
|
|
575
674
|
Target tensor.
|
|
576
675
|
psnr_func : Callable, optional
|
|
577
676
|
PSNR function to use, by default `scale_invariant_psnr`.
|
|
@@ -581,6 +680,7 @@ class VAEModule(L.LightningModule):
|
|
|
581
680
|
list[float]
|
|
582
681
|
PSNR for each channel in the current batch.
|
|
583
682
|
"""
|
|
683
|
+
# TODO check this! Related to is_supervised which is also wacky
|
|
584
684
|
out_channels = target.shape[1]
|
|
585
685
|
|
|
586
686
|
# get the reconstructed image
|