careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic (the original page linked to further details, which were not preserved here).
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -5,7 +5,7 @@ from typing import Any, Literal, Union
|
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
|
-
|
|
8
|
+
import torch
|
|
9
9
|
|
|
10
10
|
from careamics.config import (
|
|
11
11
|
N2VAlgorithm,
|
|
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
|
|
|
71
71
|
Learning rate scheduler name.
|
|
72
72
|
"""
|
|
73
73
|
|
|
74
|
-
def __init__(
|
|
74
|
+
def __init__(
|
|
75
|
+
self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
|
|
76
|
+
) -> None:
|
|
75
77
|
"""Lightning module for CAREamics.
|
|
76
78
|
|
|
77
79
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
|
|
|
98
100
|
self.n2v_preprocess = None
|
|
99
101
|
|
|
100
102
|
self.algorithm = algorithm_config.algorithm
|
|
101
|
-
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
103
|
+
self.model: torch.nn.Module = model_factory(algorithm_config.model)
|
|
102
104
|
self.loss_func = loss_factory(algorithm_config.loss)
|
|
103
105
|
|
|
104
106
|
# save optimizer and lr_scheduler names and parameters
|
|
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
|
|
|
122
124
|
"""
|
|
123
125
|
return self.model(x)
|
|
124
126
|
|
|
125
|
-
def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
127
|
+
def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
126
128
|
"""Training step.
|
|
127
129
|
|
|
128
130
|
Parameters
|
|
129
131
|
----------
|
|
130
|
-
batch : torch.Tensor
|
|
132
|
+
batch : torch.torch.Tensor
|
|
131
133
|
Input batch.
|
|
132
134
|
batch_idx : Any
|
|
133
135
|
Batch index.
|
|
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
|
|
|
154
156
|
self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
|
|
155
157
|
return loss
|
|
156
158
|
|
|
157
|
-
def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
|
|
159
|
+
def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
|
|
158
160
|
"""Validation step.
|
|
159
161
|
|
|
160
162
|
Parameters
|
|
161
163
|
----------
|
|
162
|
-
batch : torch.Tensor
|
|
164
|
+
batch : torch.torch.Tensor
|
|
163
165
|
Input batch.
|
|
164
166
|
batch_idx : Any
|
|
165
167
|
Batch index.
|
|
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
|
|
|
184
186
|
logger=True,
|
|
185
187
|
)
|
|
186
188
|
|
|
187
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
189
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
188
190
|
"""Prediction step.
|
|
189
191
|
|
|
190
192
|
Parameters
|
|
191
193
|
----------
|
|
192
|
-
batch : torch.Tensor
|
|
194
|
+
batch : torch.torch.torch.Tensor
|
|
193
195
|
Input batch.
|
|
194
196
|
batch_idx : Any
|
|
195
197
|
Batch index.
|
|
@@ -330,17 +332,21 @@ class VAEModule(L.LightningModule):
|
|
|
330
332
|
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
331
333
|
|
|
332
334
|
# create model
|
|
333
|
-
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
335
|
+
self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
|
|
334
336
|
|
|
337
|
+
# supervised_mode
|
|
338
|
+
self.supervised_mode = self.algorithm_config.is_supervised
|
|
335
339
|
# create loss function
|
|
336
340
|
self.noise_model: NoiseModel | None = noise_model_factory(
|
|
337
341
|
self.algorithm_config.noise_model
|
|
338
342
|
)
|
|
339
343
|
|
|
340
|
-
self.noise_model_likelihood: NoiseModelLikelihood | None =
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
+
self.noise_model_likelihood: NoiseModelLikelihood | None = None
|
|
345
|
+
if self.algorithm_config.noise_model_likelihood is not None:
|
|
346
|
+
self.noise_model_likelihood = likelihood_factory(
|
|
347
|
+
config=self.algorithm_config.noise_model_likelihood,
|
|
348
|
+
noise_model=self.noise_model,
|
|
349
|
+
)
|
|
344
350
|
|
|
345
351
|
self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
|
|
346
352
|
self.algorithm_config.gaussian_likelihood
|
|
@@ -360,30 +366,43 @@ class VAEModule(L.LightningModule):
|
|
|
360
366
|
RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
|
|
361
367
|
]
|
|
362
368
|
|
|
363
|
-
def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
|
|
369
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
|
|
364
370
|
"""Forward pass.
|
|
365
371
|
|
|
366
372
|
Parameters
|
|
367
373
|
----------
|
|
368
|
-
x : Tensor
|
|
374
|
+
x : torch.Tensor
|
|
369
375
|
Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
370
376
|
number of lateral inputs.
|
|
371
377
|
|
|
372
378
|
Returns
|
|
373
379
|
-------
|
|
374
|
-
tuple[Tensor, dict[str, Any]]
|
|
380
|
+
tuple[torch.Tensor, dict[str, Any]]
|
|
375
381
|
A tuple with the output tensor and additional data from the top-down pass.
|
|
376
382
|
"""
|
|
377
383
|
return self.model(x) # TODO Different model can have more than one output
|
|
378
384
|
|
|
385
|
+
def set_data_stats(self, data_mean, data_std):
|
|
386
|
+
"""Set data mean and std for the noise model likelihood.
|
|
387
|
+
|
|
388
|
+
Parameters
|
|
389
|
+
----------
|
|
390
|
+
data_mean : float
|
|
391
|
+
Mean of the data.
|
|
392
|
+
data_std : float
|
|
393
|
+
Standard deviation of the data.
|
|
394
|
+
"""
|
|
395
|
+
if self.noise_model_likelihood is not None:
|
|
396
|
+
self.noise_model_likelihood.set_data_stats(data_mean, data_std)
|
|
397
|
+
|
|
379
398
|
def training_step(
|
|
380
|
-
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
381
|
-
) -> dict[str, Tensor] | None:
|
|
399
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
400
|
+
) -> dict[str, torch.Tensor] | None:
|
|
382
401
|
"""Training step.
|
|
383
402
|
|
|
384
403
|
Parameters
|
|
385
404
|
----------
|
|
386
|
-
batch : tuple[Tensor, Tensor]
|
|
405
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
387
406
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
388
407
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
389
408
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -397,15 +416,30 @@ class VAEModule(L.LightningModule):
|
|
|
397
416
|
Any
|
|
398
417
|
Loss value.
|
|
399
418
|
"""
|
|
400
|
-
x, target = batch
|
|
419
|
+
x, *target = batch
|
|
401
420
|
|
|
402
421
|
# Forward pass
|
|
403
422
|
out = self.model(x)
|
|
423
|
+
if not self.supervised_mode:
|
|
424
|
+
target = x
|
|
425
|
+
else:
|
|
426
|
+
target = target[
|
|
427
|
+
0
|
|
428
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the dataset level
|
|
404
429
|
|
|
405
430
|
# Update loss parameters
|
|
406
431
|
self.loss_parameters.kl_params.current_epoch = self.current_epoch
|
|
407
432
|
|
|
408
433
|
# Compute loss
|
|
434
|
+
if self.noise_model_likelihood is not None:
|
|
435
|
+
if (
|
|
436
|
+
self.noise_model_likelihood.data_mean is None
|
|
437
|
+
or self.noise_model_likelihood.data_std is None
|
|
438
|
+
):
|
|
439
|
+
raise RuntimeError(
|
|
440
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before"
|
|
441
|
+
"training."
|
|
442
|
+
)
|
|
409
443
|
loss = self.loss_func(
|
|
410
444
|
model_outputs=out,
|
|
411
445
|
targets=target,
|
|
@@ -417,15 +451,26 @@ class VAEModule(L.LightningModule):
|
|
|
417
451
|
# Logging
|
|
418
452
|
# TODO: implement a separate logging method?
|
|
419
453
|
self.log_dict(loss, on_step=True, on_epoch=True)
|
|
420
|
-
|
|
454
|
+
|
|
455
|
+
try:
|
|
456
|
+
optimizer = self.optimizers()
|
|
457
|
+
current_lr = optimizer.param_groups[0]["lr"]
|
|
458
|
+
self.log(
|
|
459
|
+
"learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
|
|
460
|
+
)
|
|
461
|
+
except RuntimeError:
|
|
462
|
+
# This happens when the module is not attached to a trainer, e.g., in tests
|
|
463
|
+
pass
|
|
421
464
|
return loss
|
|
422
465
|
|
|
423
|
-
def validation_step(
|
|
466
|
+
def validation_step(
|
|
467
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
468
|
+
) -> None:
|
|
424
469
|
"""Validation step.
|
|
425
470
|
|
|
426
471
|
Parameters
|
|
427
472
|
----------
|
|
428
|
-
batch : tuple[Tensor, Tensor]
|
|
473
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
429
474
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
430
475
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
431
476
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -434,11 +479,16 @@ class VAEModule(L.LightningModule):
|
|
|
434
479
|
batch_idx : Any
|
|
435
480
|
Batch index.
|
|
436
481
|
"""
|
|
437
|
-
x, target = batch
|
|
482
|
+
x, *target = batch
|
|
438
483
|
|
|
439
484
|
# Forward pass
|
|
440
485
|
out = self.model(x)
|
|
441
|
-
|
|
486
|
+
if not self.supervised_mode:
|
|
487
|
+
target = x
|
|
488
|
+
else:
|
|
489
|
+
target = target[
|
|
490
|
+
0
|
|
491
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the datasel level
|
|
442
492
|
# Compute loss
|
|
443
493
|
loss = self.loss_func(
|
|
444
494
|
model_outputs=out,
|
|
@@ -464,12 +514,12 @@ class VAEModule(L.LightningModule):
|
|
|
464
514
|
else:
|
|
465
515
|
self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
|
|
466
516
|
|
|
467
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
517
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
468
518
|
"""Prediction step.
|
|
469
519
|
|
|
470
520
|
Parameters
|
|
471
521
|
----------
|
|
472
|
-
batch : Tensor
|
|
522
|
+
batch : torch.Tensor
|
|
473
523
|
Input batch.
|
|
474
524
|
batch_idx : Any
|
|
475
525
|
Batch index.
|
|
@@ -479,36 +529,86 @@ class VAEModule(L.LightningModule):
|
|
|
479
529
|
Any
|
|
480
530
|
Model output.
|
|
481
531
|
"""
|
|
482
|
-
if self.
|
|
532
|
+
if self.algorithm_config.algorithm == "microsplit":
|
|
483
533
|
x, *aux = batch
|
|
484
|
-
|
|
485
|
-
x
|
|
486
|
-
aux = []
|
|
534
|
+
# Reset model for inference with spatial dimensions only (H, W)
|
|
535
|
+
self.model.reset_for_inference(x.shape[-2:])
|
|
487
536
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
493
|
-
augmented_output = []
|
|
494
|
-
for augmented in augmented_batch:
|
|
495
|
-
augmented_pred = self.model(augmented)
|
|
496
|
-
augmented_output.append(augmented_pred)
|
|
497
|
-
output = tta.backward(augmented_output)
|
|
498
|
-
else:
|
|
499
|
-
output = self.model(x)
|
|
537
|
+
rec_img_list = []
|
|
538
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
539
|
+
# get model output
|
|
540
|
+
rec, _ = self.model(x)
|
|
500
541
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
542
|
+
# get reconstructed img
|
|
543
|
+
if self.model.predict_logvar is None:
|
|
544
|
+
rec_img = rec
|
|
545
|
+
_logvar = torch.tensor([-1])
|
|
546
|
+
else:
|
|
547
|
+
rec_img, _logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
548
|
+
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
549
|
+
|
|
550
|
+
# aggregate results
|
|
551
|
+
samples = torch.cat(rec_img_list, dim=0)
|
|
552
|
+
mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
|
|
553
|
+
std_imgs = torch.std(samples, dim=0) # std over MMSE dim
|
|
554
|
+
|
|
555
|
+
tile_prediction = mmse_imgs.cpu().numpy()
|
|
556
|
+
tile_std = std_imgs.cpu().numpy()
|
|
557
|
+
|
|
558
|
+
return tile_prediction, tile_std
|
|
507
559
|
|
|
508
|
-
if len(aux) > 0: # aux can be tiling information
|
|
509
|
-
return denormalized_output, *aux
|
|
510
560
|
else:
|
|
511
|
-
|
|
561
|
+
# Regular prediction logic
|
|
562
|
+
if self._trainer.datamodule.tiled:
|
|
563
|
+
# TODO tile_size should match model input size
|
|
564
|
+
x, *aux = batch
|
|
565
|
+
x = (
|
|
566
|
+
x[0] if isinstance(x, list | tuple) else x
|
|
567
|
+
) # TODO ugly, so far i don't know why x might be a list
|
|
568
|
+
self.model.reset_for_inference(x.shape) # TODO should it be here ?
|
|
569
|
+
else:
|
|
570
|
+
x = batch[0] if isinstance(batch, list | tuple) else batch
|
|
571
|
+
aux = []
|
|
572
|
+
self.model.reset_for_inference(x.shape)
|
|
573
|
+
|
|
574
|
+
mmse_list = []
|
|
575
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
576
|
+
# apply test-time augmentation if available
|
|
577
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
578
|
+
tta = ImageRestorationTTA()
|
|
579
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
580
|
+
augmented_output = []
|
|
581
|
+
for augmented in augmented_batch:
|
|
582
|
+
augmented_pred = self.model(augmented)
|
|
583
|
+
augmented_output.append(augmented_pred)
|
|
584
|
+
output = tta.backward(augmented_output)
|
|
585
|
+
else:
|
|
586
|
+
output = self.model(x)
|
|
587
|
+
|
|
588
|
+
# taking the 1st element of the output, 2nd is std if
|
|
589
|
+
# predict_logvar=="pixelwise"
|
|
590
|
+
output = (
|
|
591
|
+
output[0]
|
|
592
|
+
if self.model.predict_logvar is None
|
|
593
|
+
else output[0][:, 0:1, ...]
|
|
594
|
+
)
|
|
595
|
+
mmse_list.append(output)
|
|
596
|
+
|
|
597
|
+
mmse = torch.stack(mmse_list).mean(0)
|
|
598
|
+
std = torch.stack(mmse_list).std(0) # TODO why?
|
|
599
|
+
# TODO better way to unpack if pred logvar
|
|
600
|
+
# Denormalize the output
|
|
601
|
+
denorm = Denormalize(
|
|
602
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
603
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
604
|
+
)
|
|
605
|
+
|
|
606
|
+
denormalized_output = denorm(patch=mmse.cpu().numpy())
|
|
607
|
+
|
|
608
|
+
if len(aux) > 0: # aux can be tiling information
|
|
609
|
+
return denormalized_output, std, *aux
|
|
610
|
+
else:
|
|
611
|
+
return denormalized_output, std
|
|
512
612
|
|
|
513
613
|
def configure_optimizers(self) -> Any:
|
|
514
614
|
"""Configure optimizers and learning rate schedulers.
|
|
@@ -537,19 +637,19 @@ class VAEModule(L.LightningModule):
|
|
|
537
637
|
# should we refactor LadderVAE so that it already outputs
|
|
538
638
|
# tuple(`mean`, `logvar`, `td_data`)?
|
|
539
639
|
def get_reconstructed_tensor(
|
|
540
|
-
self, model_outputs: tuple[Tensor, dict[str, Any]]
|
|
541
|
-
) -> Tensor:
|
|
640
|
+
self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
|
|
641
|
+
) -> torch.Tensor:
|
|
542
642
|
"""Get the reconstructed tensor from the LVAE model outputs.
|
|
543
643
|
|
|
544
644
|
Parameters
|
|
545
645
|
----------
|
|
546
|
-
model_outputs : tuple[Tensor, dict[str, Any]]
|
|
646
|
+
model_outputs : tuple[torch.Tensor, dict[str, Any]]
|
|
547
647
|
Model outputs. It is a tuple with a tensor representing the predicted mean
|
|
548
648
|
and (optionally) logvar, and the top-down data dictionary.
|
|
549
649
|
|
|
550
650
|
Returns
|
|
551
651
|
-------
|
|
552
|
-
Tensor
|
|
652
|
+
torch.Tensor
|
|
553
653
|
Reconstructed tensor, i.e., the predicted mean.
|
|
554
654
|
"""
|
|
555
655
|
predictions, _ = model_outputs
|
|
@@ -560,18 +660,18 @@ class VAEModule(L.LightningModule):
|
|
|
560
660
|
|
|
561
661
|
def compute_val_psnr(
|
|
562
662
|
self,
|
|
563
|
-
model_output: tuple[Tensor, dict[str, Any]],
|
|
564
|
-
target: Tensor,
|
|
663
|
+
model_output: tuple[torch.Tensor, dict[str, Any]],
|
|
664
|
+
target: torch.Tensor,
|
|
565
665
|
psnr_func: Callable = scale_invariant_psnr,
|
|
566
666
|
) -> list[float]:
|
|
567
667
|
"""Compute the PSNR for the current validation batch.
|
|
568
668
|
|
|
569
669
|
Parameters
|
|
570
670
|
----------
|
|
571
|
-
model_output : tuple[Tensor, dict[str, Any]]
|
|
671
|
+
model_output : tuple[torch.Tensor, dict[str, Any]]
|
|
572
672
|
Model output, a tuple with the predicted mean and (optionally) logvar,
|
|
573
673
|
and the top-down data dictionary.
|
|
574
|
-
target : Tensor
|
|
674
|
+
target : torch.Tensor
|
|
575
675
|
Target tensor.
|
|
576
676
|
psnr_func : Callable, optional
|
|
577
677
|
PSNR function to use, by default `scale_invariant_psnr`.
|
|
@@ -581,6 +681,7 @@ class VAEModule(L.LightningModule):
|
|
|
581
681
|
list[float]
|
|
582
682
|
PSNR for each channel in the current batch.
|
|
583
683
|
"""
|
|
684
|
+
# TODO check this! Related to is_supervised which is also wacky
|
|
584
685
|
out_channels = target.shape[1]
|
|
585
686
|
|
|
586
687
|
# get the reconstructed image
|