careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/eval_utils.py
CHANGED
@@ -6,13 +6,13 @@ It includes functions to:
 - create plots to visualize the results.
 """
 
-import math
 import os
-from typing import
+from typing import List, Literal, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
+from scipy import stats
 import torch
 from torch import nn
 from torch.utils.data import Dataset
@@ -21,14 +21,21 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from careamics.lightning import VAEModule
-
-    get_reconstruction_loss,
-    reconstruction_loss_musplit_denoisplit,
-)
+
 from careamics.models.lvae.utils import ModelType
 from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
 
 
+class TilingMode:
+    """
+    Enum for the tiling mode.
+    """
+
+    TrimBoundary = 0
+    PadBoundary = 1
+    ShiftBoundary = 2
+
+
 # ------------------------------------------------------------------------------------------------
 # Function of plotting: TODO -> moved them to another file, plot_utils.py
 def clean_ax(ax):
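Note: the TilingMode added above is an ad-hoc enum (plain integer class attributes) rather than a standard-library Enum. A minimal sketch of an equivalent enum.IntEnum formulation, shown only as a design alternative, not what this diff ships:

    from enum import IntEnum


    class TilingModeEnum(IntEnum):
        """Type-safe equivalent of the ad-hoc TilingMode holder."""

        TrimBoundary = 0
        PadBoundary = 1
        ShiftBoundary = 2


    # IntEnum members compare equal to plain integers, so a comparison such as
    # `mng.tiling_mode == TilingMode.ShiftBoundary` later in this file would
    # behave identically with either formulation.
    assert TilingModeEnum.ShiftBoundary == 2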
@@ -596,51 +603,49 @@ def get_dset_predictions(
             logvar_arr.append(logvar.cpu().numpy())
 
             # compute reconstruction loss
-            if loss_type == "musplit":
-                rec_loss = get_reconstruction_loss(
-                    reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
-                )
-            elif loss_type == "denoisplit":
-                rec_loss = get_reconstruction_loss(
-                    reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
-                )
-            elif loss_type == "denoisplit_musplit":
-                rec_loss = reconstruction_loss_musplit_denoisplit(
-                    predictions=rec,
-                    targets=tar,
-                    gaussian_likelihood=gauss_likelihood,
-                    nm_likelihood=nm_likelihood,
-                    nm_weight=model.loss_parameters.denoisplit_weight,
-                    gaussian_weight=model.loss_parameters.musplit_weight,
-                )
-            rec_loss = {"loss": rec_loss}  # hacky, but ok for now
-
-            # store rec loss values for first pred
-            if mmse_idx == 0:
-                try:
-                    losses.append(rec_loss["loss"].cpu().numpy())
-                except:
-                    losses.append(rec_loss["loss"])
+            # if loss_type == "musplit":
+            #     rec_loss = get_reconstruction_loss(
+            #         reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
+            #     )
+            # elif loss_type == "denoisplit":
+            #     rec_loss = get_reconstruction_loss(
+            #         reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
+            #     )
+            # elif loss_type == "denoisplit_musplit":
+            #     rec_loss = reconstruction_loss_musplit_denoisplit(
+            #         predictions=rec,
+            #         targets=tar,
+            #         gaussian_likelihood=gauss_likelihood,
+            #         nm_likelihood=nm_likelihood,
+            #         nm_weight=model.loss_parameters.denoisplit_weight,
+            #         gaussian_weight=model.loss_parameters.musplit_weight,
+            #     )
+            # rec_loss = {"loss": rec_loss}  # hacky, but ok for now
+
+            # # store rec loss values for first pred
+            # if mmse_idx == 0:
+            #     try:
+            #         losses.append(rec_loss["loss"].cpu().numpy())
+            #     except:
+            #         losses.append(rec_loss["loss"])
 
             # update running PSNR
-            for i in range(num_channels):
-                patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
+            # for i in range(num_channels):
+            #     patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
 
         # aggregate results
         samples = torch.cat(rec_img_list, dim=0)
         mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
-        mmse_std = torch.std(samples, dim=0)
+        # mmse_std = torch.std(samples, dim=0)
         predictions.append(mmse_imgs.cpu().numpy())
-        predictions_std.append(mmse_std.cpu().numpy())
-
-    psnr = [x.get() for x in patch_psnr_channels]
-    return (
-        np.concatenate(predictions, axis=0),
-        np.concatenate(predictions_std, axis=0),
-        np.concatenate(logvar_arr),
-        np.array(losses),
-        psnr,
-    )
+        # predictions_std.append(mmse_std.cpu().numpy())
+
+    # psnr = [x.get() for x in patch_psnr_channels]
+    return np.concatenate(predictions, axis=0)
+    # np.concatenate(predictions_std, axis=0),
+    # np.concatenate(logvar_arr),
+    # np.array(losses),
+    # psnr, # TODO revisit !
 
 
 # ------------------------------------------------------------------------------------------
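Note: with the edits above, get_dset_predictions no longer returns a five-element tuple (predictions, their std, log-variances, losses, per-channel PSNR); it now returns only the concatenated MMSE predictions. The kept aggregation is a mean over the MMSE sample dimension; a small self-contained sketch of that step, with illustrative shapes:

    import torch

    # mmse_count stochastic reconstructions of the same input batch
    mmse_count, channels, height, width = 5, 2, 8, 8
    rec_img_list = [torch.randn(1, channels, height, width) for _ in range(mmse_count)]

    samples = torch.cat(rec_img_list, dim=0)  # (mmse_count, C, H, W)
    mmse_imgs = torch.mean(samples, dim=0)    # MMSE estimate: average over samples
    # The now-commented-out std would give a per-pixel uncertainty map:
    # mmse_std = torch.std(samples, dim=0)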
@@ -773,178 +778,85 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
     return output
 
 
-#
-
-
-# ------------------------------------------------------------------------------------------
-### Classes and Functions used for Calibration
-class Calibration:
-
-    def __init__(
-        self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise"
-    ):
-        self._bins = num_bins
-        self._bin_boundaries = None
-        self._mode = mode
-        assert mode in ["pixelwise", "patchwise"]
-        self._boundary_mode = "uniform"
-        assert self._boundary_mode in ["quantile", "uniform"]
-        # self._bin_boundaries = {}
-
-    def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
-        return np.exp(logvar / 2)
-
-    def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray:
-        """
-        Compute the bin boundaries for `num_bins` bins and the given logvar values.
-        """
-        if self._boundary_mode == "quantile":
-            boundaries = np.quantile(
-                self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1)
-            )
-            return boundaries
-        else:
-            min_logvar = np.min(predict_logvar)
-            max_logvar = np.max(predict_logvar)
-            min_std = self.logvar_to_std(min_logvar)
-            max_std = self.logvar_to_std(max_logvar)
-            return np.linspace(min_std, max_std, self._bins + 1)
-
-    def compute_stats(
-        self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray
-    ) -> Dict[int, Dict[str, Union[np.ndarray, List]]]:
-        """
-        It computes the bin-wise RMSE and RMV for each channel of the predicted image.
-
-        Recall that:
-        - RMSE = np.sqrt((pred - target)**2 / num_pixels)
-        - RMV = np.sqrt(np.mean(pred_std**2))
-
-        ALGORITHM
-        - For each channel:
-            - Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
-            - For each bin index:
-                - Compute the RMSE, RMV, and number of pixels for that bin.
-
-        NOTE: each channel of the predicted image/logvar has its own stats.
-
-        Args:
-            pred: np.ndarray, shape (n, h, w, c)
-            pred_logvar: np.ndarray, shape (n, h, w, c)
-            target: np.ndarray, shape (n, h, w, c)
-        """
-        self._bin_boundaries = {}
-        stats = {}
-        for ch_idx in range(pred.shape[-1]):
-            stats[ch_idx] = {
-                "bin_count": [],
-                "rmv": [],
-                "rmse": [],
-                "bin_boundaries": None,
-                "bin_matrix": [],
-            }
-            pred_ch = pred[..., ch_idx]
-            logvar_ch = pred_logvar[..., ch_idx]
-            std_ch = self.logvar_to_std(logvar_ch)
-            target_ch = target[..., ch_idx]
-            if self._mode == "pixelwise":
-                boundaries = self.compute_bin_boundaries(logvar_ch)
-                stats[ch_idx]["bin_boundaries"] = boundaries
-                bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
-                bin_matrix = bin_matrix.reshape(std_ch.shape)
-                stats[ch_idx]["bin_matrix"] = bin_matrix
-                error = (pred_ch - target_ch) ** 2
-                for bin_idx in range(self._bins):
-                    bin_mask = bin_matrix == bin_idx
-                    bin_error = error[bin_mask]
-                    bin_size = np.sum(bin_mask)
-                    bin_error = (
-                        np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
-                    )  # RMSE
-                    bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2))  # RMV
-                    stats[ch_idx]["rmse"].append(bin_error)
-                    stats[ch_idx]["rmv"].append(bin_var)
-                    stats[ch_idx]["bin_count"].append(bin_size)
-            else:
-                raise NotImplementedError("Patchwise mode is not implemented yet.")
-        return stats
-
-
-def nll(x, mean, logvar):
+# from disentangle.analysis.stitch_prediction import *
+def stitch_predictions_new(predictions, dset):
     """
-
-
-
-    :param x: tensor of points, with shape (batch, channels, dim1, dim2)
-    :param mean: tensor with mean of distribution, shape
-                 (batch, channels, dim1, dim2)
-    :param logvar: tensor with log-variance of distribution, shape has to be
-                   either scalar or broadcastable
-    """
-    var = torch.exp(logvar)
-    log_prob = -0.5 * (
-        ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
-    )
-    nll = -log_prob
-    return nll
-
-
-def get_calibrated_factor_for_stdev(
-    pred: Union[np.ndarray, torch.Tensor],
-    pred_logvar: Union[np.ndarray, torch.Tensor],
-    target: Union[np.ndarray, torch.Tensor],
-    batch_size: int = 32,
-    epochs: int = 500,
-    lr: float = 0.01,
-):
-    """
-    Here, we calibrate the uncertainty by multiplying the predicted std (mmse estimate or predicted logvar) with a scalar.
-    We return the calibrated scalar. This needs to be multiplied with the std.
-
-    NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std.
+    Args:
+        smoothening_pixelcount: number of pixels which can be interpolated
     """
-    #
-
-
+    # Commented out since it is not used as of now
+    # if isinstance(dset, MultiFileDset):
+    #     cum_count = 0
+    #     output = []
+    #     for dset in dset.dsets:
+    #         cnt = dset.idx_manager.total_grid_count()
+    #         output.append(
+    #             stitch_predictions(predictions[cum_count:cum_count + cnt], dset))
+    #         cum_count += cnt
+    #     return output
 
-
-
-        optimizer.zero_grad()
-        # Select a random batch of predictions
-        mask = np.random.randint(0, pred.shape[0], batch_size)
-        pred_batch = torch.Tensor(pred[mask]).cuda()
-        pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda()
-        target_batch = torch.Tensor(target[mask]).cuda()
+    # else:
+    mng = dset.idx_manager
 
-
-
-
-        loss.backward()
-        optimizer.step()
-        bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}")
-
-    return np.sqrt(scalar.item())
-
-
-def plot_calibration(ax, calibration_stats):
-    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
-    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
-    ax.plot(
-        calibration_stats[0]["rmv"][first_idx:-last_idx],
-        calibration_stats[0]["rmse"][first_idx:-last_idx],
-        "o",
-        label=r"$\hat{C}_0$: Ch1",
-    )
+    # if there are more channels, use all of them.
+    shape = list(dset.get_data_shape())
+    shape[-1] = max(shape[-1], predictions.shape[1])
 
-
-
-
-
-
-
-
-
+    output = np.zeros(shape, dtype=predictions.dtype)
+    # frame_shape = dset.get_data_shape()[:-1]
+    for dset_idx in range(predictions.shape[0]):
+        # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
+        # grid start, grid end
+        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
+        ge = gs + mng.grid_shape
+
+        # patch start, patch end
+        ps = gs - mng.patch_offset()
+        pe = ps + mng.patch_shape
+        # print('PS')
+        # print(ps)
+        # print(pe)
+
+        # valid grid start, valid grid end
+        vgs = np.array([max(0, x) for x in gs], dtype=int)
+        vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
+        # assert np.all(vgs == gs)
+        # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest to dig why it's failing at this point !
+        # print('VGS')
+        # print(gs)
+        # print(ge)
+
+        if mng.tiling_mode == TilingMode.ShiftBoundary:
+            for dim in range(len(vgs)):
+                if ps[dim] == 0:
+                    vgs[dim] = 0
+                if pe[dim] == mng.data_shape[dim]:
+                    vge[dim] = mng.data_shape[dim]
+
+        # relative start, relative end. This will be used on pred_tiled
+        rs = vgs - ps
+        re = rs + (vge - vgs)
+        # print('RS')
+        # print(rs)
+        # print(re)
+
+        # print(output.shape)
+        # print(predictions.shape)
+        for ch_idx in range(predictions.shape[1]):
+            if len(output.shape) == 4:
+                # channel dimension is the last one.
+                output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
+                    predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
+                )
+            elif len(output.shape) == 5:
+                # channel dimension is the last one.
+                assert vge[0] - vgs[0] == 1, "Only one frame is supported"
+                output[
+                    vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
+                ] = predictions[dset_idx][
+                    ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
+                ]
+            else:
+                raise ValueError(f"Unsupported shape {output.shape}")
 
-
-    ax.set_ylabel("RMSE")
-    ax.legend()
+    return output
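Note: stitch_predictions_new writes each predicted patch back into the full frame using four coordinate systems: grid start/end (gs/ge), patch start/end (ps/pe), the valid portion clipped to the frame (vgs/vge), and the corresponding relative slice inside the patch (rs/re). A 1-D self-contained sketch of that arithmetic; the names mirror the diff, the sizes are illustrative:

    import numpy as np

    data_len, patch_size, offset = 20, 8, 2  # illustrative sizes

    gs = 6                   # grid start
    ge = gs + 4              # grid end (this grid cell covers 4 pixels)
    ps = gs - offset         # patch start (can go negative at the left border)
    pe = ps + patch_size     # patch end (can exceed data_len at the right border)

    vgs = max(0, gs)         # valid grid start, clipped to the frame
    vge = min(ge, data_len)  # valid grid end, clipped to the frame
    # In ShiftBoundary mode, vgs/vge are additionally widened to the frame edges
    # whenever ps == 0 or pe == data_len, so border tiles cover the full border.

    rs = vgs - ps            # relative start inside the predicted patch
    re = rs + (vge - vgs)    # relative end inside the predicted patch

    output = np.zeros(data_len)
    pred_patch = np.ones(patch_size)
    output[vgs:vge] = pred_patch[rs:re]  # only the valid region is written back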
careamics/model_io/bioimage/_readme_factory.py
CHANGED
@@ -1,7 +1,6 @@
 """Functions used to create a README.md file for BMZ export."""
 
 from pathlib import Path
-from typing import Optional
 
 import yaml
 
@@ -28,7 +27,7 @@ def _yaml_block(yaml_str: str) -> str:
 def readme_factory(
     config: Configuration,
     careamics_version: str,
-    data_description:
+    data_description: str,
 ) -> Path:
     """Create a README file for the model.
 
@@ -41,18 +40,14 @@
         CAREamics configuration.
     careamics_version : str
         CAREamics version.
-    data_description :
-        Description of the data
+    data_description : str
+        Description of the data.
 
     Returns
     -------
     Path
         Path to the README file.
     """
-    algorithm = config.algorithm_config
-    training = config.training_config
-    data = config.data_config
-
     # create file
     # TODO use tempfile as in the bmz_io module
     with cwd(get_careamics_home()):
@@ -60,47 +55,44 @@
         readme.touch()
 
         # algorithm pretty name
-        algorithm_flavour = config.
+        algorithm_flavour = config.get_algorithm_friendly_name()
         algorithm_pretty_name = algorithm_flavour + " - CAREamics"
 
         description = [f"# {algorithm_pretty_name}\n\n"]
 
+        # data description
+        description.append("## Data description\n\n")
+        description.append(data_description)
+        description.append("\n\n")
+
         # algorithm description
-        description.append("Algorithm description:\n\n")
+        description.append("## Algorithm description:\n\n")
         description.append(config.get_algorithm_description())
         description.append("\n\n")
 
-        #
+        # configuration description
+        description.append("## Configuration\n\n")
+
         description.append(
             f"{algorithm_flavour} was trained using CAREamics (version "
-            f"{careamics_version})
-            f"parameters:\n\n"
-        )
-        description.append(
-            _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
+            f"{careamics_version}) using the following configuration:\n\n"
         )
-        description.append("\n\n")
-
-        # data description
-        description.append("## Data description\n\n")
-        if data_description is not None:
-            description.append(data_description)
-        description.append("\n\n")
-
-        description.append("The data was processed using the following parameters:\n\n")
 
-        description.append(_yaml_block(yaml.dump(
+        description.append(_yaml_block(yaml.dump(config.model_dump(exclude_none=True))))
         description.append("\n\n")
 
-        #
-        description.append("
-
-        description.append("The model was trained using the following parameters:\n\n")
+        # validation
+        description.append("# Validation\n\n")
 
         description.append(
-
+            "In order to validate the model, we encourage users to acquire a "
+            "test dataset with ground-truth data. Comparing the ground-truth data "
+            "with the prediction allows unbiased evaluation of the model performances. "
+            "This can be done for instance by using metrics such as PSNR, SSIM, or"
+            "MicroSSIM. In the absence of ground-truth, inspecting the residual image "
+            "(difference between input and predicted image) can be helpful to identify "
+            "whether real signal is removed from the input image.\n\n"
         )
-        description.append("\n\n")
 
         # references
         reference = config.get_algorithm_references()
@@ -111,9 +103,9 @@
 
     # links
     description.append(
-        "
+        "# Links\n\n"
        "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
-        "- [CAREamics documentation](https://careamics.github.io/
+        "- [CAREamics documentation](https://careamics.github.io/)\n"
     )
 
     readme.write_text("".join(description))
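Note: after this refactor, readme_factory requires a data_description string and dumps the whole configuration once (config.model_dump) instead of separate algorithm/data/training dumps. A hedged sketch of the README skeleton it assembles, reconstructed from the strings visible in this diff; angle-bracket placeholders stand for generated content:

    # Illustrative only: the real readme_factory derives these strings from a
    # Configuration object and writes them to a README file on disk.
    sections = [
        "# <algorithm friendly name> - CAREamics\n\n",
        "## Data description\n\n<data_description>\n\n",
        "## Algorithm description:\n\n<config.get_algorithm_description()>\n\n",
        "## Configuration\n\n",
        "<flavour> was trained using CAREamics (version <careamics_version>) "
        "using the following configuration:\n\n",
        "<yaml block of config.model_dump(exclude_none=True)>\n\n",
        "# Validation\n\n<guidance: PSNR, SSIM, MicroSSIM, or residual inspection>\n\n",
        "# Links\n\n"
        "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
        "- [CAREamics documentation](https://careamics.github.io/)\n",
    ]
    readme_text = "".join(sections)
    print(readme_text)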