careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics has been flagged as potentially problematic; see the registry's advisory page for this release for more details.

Files changed (43)
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +39 -17
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
@@ -6,13 +6,13 @@ It includes functions to:
6
6
  - create plots to visualize the results.
7
7
  """
8
8
 
9
- import math
10
9
  import os
11
- from typing import Dict, List, Literal, Union
10
+ from typing import List, Literal, Union
12
11
 
13
12
  import matplotlib
14
13
  import matplotlib.pyplot as plt
15
14
  import numpy as np
15
+ from scipy import stats
16
16
  import torch
17
17
  from torch import nn
18
18
  from torch.utils.data import Dataset
@@ -21,14 +21,21 @@ from torch.utils.data import DataLoader
21
21
  from tqdm import tqdm
22
22
 
23
23
  from careamics.lightning import VAEModule
24
- from careamics.losses.lvae.losses import (
25
- get_reconstruction_loss,
26
- reconstruction_loss_musplit_denoisplit,
27
- )
24
+
28
25
  from careamics.models.lvae.utils import ModelType
29
26
  from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
30
27
 
31
28
 
29
+ class TilingMode:
30
+ """
31
+ Enum for the tiling mode.
32
+ """
33
+
34
+ TrimBoundary = 0
35
+ PadBoundary = 1
36
+ ShiftBoundary = 2
37
+
38
+
32
39
  # ------------------------------------------------------------------------------------------------
33
40
  # Function of plotting: TODO -> moved them to another file, plot_utils.py
34
41
  def clean_ax(ax):
@@ -596,51 +603,49 @@ def get_dset_predictions(
596
603
  logvar_arr.append(logvar.cpu().numpy())
597
604
 
598
605
  # compute reconstruction loss
599
- if loss_type == "musplit":
600
- rec_loss = get_reconstruction_loss(
601
- reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
602
- )
603
- elif loss_type == "denoisplit":
604
- rec_loss = get_reconstruction_loss(
605
- reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
606
- )
607
- elif loss_type == "denoisplit_musplit":
608
- rec_loss = reconstruction_loss_musplit_denoisplit(
609
- predictions=rec,
610
- targets=tar,
611
- gaussian_likelihood=gauss_likelihood,
612
- nm_likelihood=nm_likelihood,
613
- nm_weight=model.loss_parameters.denoisplit_weight,
614
- gaussian_weight=model.loss_parameters.musplit_weight,
615
- )
616
- rec_loss = {"loss": rec_loss} # hacky, but ok for now
617
-
618
- # store rec loss values for first pred
619
- if mmse_idx == 0:
620
- try:
621
- losses.append(rec_loss["loss"].cpu().numpy())
622
- except:
623
- losses.append(rec_loss["loss"])
606
+ # if loss_type == "musplit":
607
+ # rec_loss = get_reconstruction_loss(
608
+ # reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
609
+ # )
610
+ # elif loss_type == "denoisplit":
611
+ # rec_loss = get_reconstruction_loss(
612
+ # reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
613
+ # )
614
+ # elif loss_type == "denoisplit_musplit":
615
+ # rec_loss = reconstruction_loss_musplit_denoisplit(
616
+ # predictions=rec,
617
+ # targets=tar,
618
+ # gaussian_likelihood=gauss_likelihood,
619
+ # nm_likelihood=nm_likelihood,
620
+ # nm_weight=model.loss_parameters.denoisplit_weight,
621
+ # gaussian_weight=model.loss_parameters.musplit_weight,
622
+ # )
623
+ # rec_loss = {"loss": rec_loss} # hacky, but ok for now
624
+
625
+ # # store rec loss values for first pred
626
+ # if mmse_idx == 0:
627
+ # try:
628
+ # losses.append(rec_loss["loss"].cpu().numpy())
629
+ # except:
630
+ # losses.append(rec_loss["loss"])
624
631
 
625
632
  # update running PSNR
626
- for i in range(num_channels):
627
- patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
633
+ # for i in range(num_channels):
634
+ # patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
628
635
 
629
636
  # aggregate results
630
637
  samples = torch.cat(rec_img_list, dim=0)
631
638
  mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
632
- mmse_std = torch.std(samples, dim=0)
639
+ # mmse_std = torch.std(samples, dim=0)
633
640
  predictions.append(mmse_imgs.cpu().numpy())
634
- predictions_std.append(mmse_std.cpu().numpy())
635
-
636
- psnr = [x.get() for x in patch_psnr_channels]
637
- return (
638
- np.concatenate(predictions, axis=0),
639
- np.concatenate(predictions_std, axis=0),
640
- np.concatenate(logvar_arr),
641
- np.array(losses),
642
- psnr,
643
- )
641
+ # predictions_std.append(mmse_std.cpu().numpy())
642
+
643
+ # psnr = [x.get() for x in patch_psnr_channels]
644
+ return np.concatenate(predictions, axis=0)
645
+ # np.concatenate(predictions_std, axis=0),
646
+ # np.concatenate(logvar_arr),
647
+ # np.array(losses),
648
+ # psnr, # TODO revisit !
644
649
 
645
650
 
646
651
  # ------------------------------------------------------------------------------------------
@@ -773,178 +778,85 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
773
778
  return output
774
779
 
775
780
 
776
- # ------------------------------------------------------------------------------------------
777
-
778
-
779
- # ------------------------------------------------------------------------------------------
780
- ### Classes and Functions used for Calibration
781
- class Calibration:
782
-
783
- def __init__(
784
- self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise"
785
- ):
786
- self._bins = num_bins
787
- self._bin_boundaries = None
788
- self._mode = mode
789
- assert mode in ["pixelwise", "patchwise"]
790
- self._boundary_mode = "uniform"
791
- assert self._boundary_mode in ["quantile", "uniform"]
792
- # self._bin_boundaries = {}
793
-
794
- def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
795
- return np.exp(logvar / 2)
796
-
797
- def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray:
798
- """
799
- Compute the bin boundaries for `num_bins` bins and the given logvar values.
800
- """
801
- if self._boundary_mode == "quantile":
802
- boundaries = np.quantile(
803
- self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1)
804
- )
805
- return boundaries
806
- else:
807
- min_logvar = np.min(predict_logvar)
808
- max_logvar = np.max(predict_logvar)
809
- min_std = self.logvar_to_std(min_logvar)
810
- max_std = self.logvar_to_std(max_logvar)
811
- return np.linspace(min_std, max_std, self._bins + 1)
812
-
813
- def compute_stats(
814
- self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray
815
- ) -> Dict[int, Dict[str, Union[np.ndarray, List]]]:
816
- """
817
- It computes the bin-wise RMSE and RMV for each channel of the predicted image.
818
-
819
- Recall that:
820
- - RMSE = np.sqrt((pred - target)**2 / num_pixels)
821
- - RMV = np.sqrt(np.mean(pred_std**2))
822
-
823
- ALGORITHM
824
- - For each channel:
825
- - Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
826
- - For each bin index:
827
- - Compute the RMSE, RMV, and number of pixels for that bin.
828
-
829
- NOTE: each channel of the predicted image/logvar has its own stats.
830
-
831
- Args:
832
- pred: np.ndarray, shape (n, h, w, c)
833
- pred_logvar: np.ndarray, shape (n, h, w, c)
834
- target: np.ndarray, shape (n, h, w, c)
835
- """
836
- self._bin_boundaries = {}
837
- stats = {}
838
- for ch_idx in range(pred.shape[-1]):
839
- stats[ch_idx] = {
840
- "bin_count": [],
841
- "rmv": [],
842
- "rmse": [],
843
- "bin_boundaries": None,
844
- "bin_matrix": [],
845
- }
846
- pred_ch = pred[..., ch_idx]
847
- logvar_ch = pred_logvar[..., ch_idx]
848
- std_ch = self.logvar_to_std(logvar_ch)
849
- target_ch = target[..., ch_idx]
850
- if self._mode == "pixelwise":
851
- boundaries = self.compute_bin_boundaries(logvar_ch)
852
- stats[ch_idx]["bin_boundaries"] = boundaries
853
- bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
854
- bin_matrix = bin_matrix.reshape(std_ch.shape)
855
- stats[ch_idx]["bin_matrix"] = bin_matrix
856
- error = (pred_ch - target_ch) ** 2
857
- for bin_idx in range(self._bins):
858
- bin_mask = bin_matrix == bin_idx
859
- bin_error = error[bin_mask]
860
- bin_size = np.sum(bin_mask)
861
- bin_error = (
862
- np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
863
- ) # RMSE
864
- bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2)) # RMV
865
- stats[ch_idx]["rmse"].append(bin_error)
866
- stats[ch_idx]["rmv"].append(bin_var)
867
- stats[ch_idx]["bin_count"].append(bin_size)
868
- else:
869
- raise NotImplementedError("Patchwise mode is not implemented yet.")
870
- return stats
871
-
872
-
873
- def nll(x, mean, logvar):
781
+ # from disentangle.analysis.stitch_prediction import *
782
+ def stitch_predictions_new(predictions, dset):
874
783
  """
875
- Log of the probability density of the values x under the Normal
876
- distribution with parameters mean and logvar.
877
-
878
- :param x: tensor of points, with shape (batch, channels, dim1, dim2)
879
- :param mean: tensor with mean of distribution, shape
880
- (batch, channels, dim1, dim2)
881
- :param logvar: tensor with log-variance of distribution, shape has to be
882
- either scalar or broadcastable
883
- """
884
- var = torch.exp(logvar)
885
- log_prob = -0.5 * (
886
- ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
887
- )
888
- nll = -log_prob
889
- return nll
890
-
891
-
892
- def get_calibrated_factor_for_stdev(
893
- pred: Union[np.ndarray, torch.Tensor],
894
- pred_logvar: Union[np.ndarray, torch.Tensor],
895
- target: Union[np.ndarray, torch.Tensor],
896
- batch_size: int = 32,
897
- epochs: int = 500,
898
- lr: float = 0.01,
899
- ):
900
- """
901
- Here, we calibrate the uncertainty by multiplying the predicted std (mmse estimate or predicted logvar) with a scalar.
902
- We return the calibrated scalar. This needs to be multiplied with the std.
903
-
904
- NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std.
784
+ Args:
785
+ smoothening_pixelcount: number of pixels which can be interpolated
905
786
  """
906
- # create a learnable scalar
907
- scalar = torch.nn.Parameter(torch.tensor(2.0))
908
- optimizer = torch.optim.Adam([scalar], lr=lr)
787
+ # Commented out since it is not used as of now
788
+ # if isinstance(dset, MultiFileDset):
789
+ # cum_count = 0
790
+ # output = []
791
+ # for dset in dset.dsets:
792
+ # cnt = dset.idx_manager.total_grid_count()
793
+ # output.append(
794
+ # stitch_predictions(predictions[cum_count:cum_count + cnt], dset))
795
+ # cum_count += cnt
796
+ # return output
909
797
 
910
- bar = tqdm(range(epochs))
911
- for _ in bar:
912
- optimizer.zero_grad()
913
- # Select a random batch of predictions
914
- mask = np.random.randint(0, pred.shape[0], batch_size)
915
- pred_batch = torch.Tensor(pred[mask]).cuda()
916
- pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda()
917
- target_batch = torch.Tensor(target[mask]).cuda()
798
+ # else:
799
+ mng = dset.idx_manager
918
800
 
919
- loss = torch.mean(
920
- nll(target_batch, pred_batch, pred_logvar_batch + torch.log(scalar))
921
- )
922
- loss.backward()
923
- optimizer.step()
924
- bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}")
925
-
926
- return np.sqrt(scalar.item())
927
-
928
-
929
- def plot_calibration(ax, calibration_stats):
930
- first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
931
- last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
932
- ax.plot(
933
- calibration_stats[0]["rmv"][first_idx:-last_idx],
934
- calibration_stats[0]["rmse"][first_idx:-last_idx],
935
- "o",
936
- label=r"$\hat{C}_0$: Ch1",
937
- )
801
+ # if there are more channels, use all of them.
802
+ shape = list(dset.get_data_shape())
803
+ shape[-1] = max(shape[-1], predictions.shape[1])
938
804
 
939
- first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
940
- last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
941
- ax.plot(
942
- calibration_stats[1]["rmv"][first_idx:-last_idx],
943
- calibration_stats[1]["rmse"][first_idx:-last_idx],
944
- "o",
945
- label=r"$\hat{C}_1: : Ch2$",
946
- )
805
+ output = np.zeros(shape, dtype=predictions.dtype)
806
+ # frame_shape = dset.get_data_shape()[:-1]
807
+ for dset_idx in range(predictions.shape[0]):
808
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
809
+ # grid start, grid end
810
+ gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
811
+ ge = gs + mng.grid_shape
812
+
813
+ # patch start, patch end
814
+ ps = gs - mng.patch_offset()
815
+ pe = ps + mng.patch_shape
816
+ # print('PS')
817
+ # print(ps)
818
+ # print(pe)
819
+
820
+ # valid grid start, valid grid end
821
+ vgs = np.array([max(0, x) for x in gs], dtype=int)
822
+ vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
823
+ # assert np.all(vgs == gs)
824
+ # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest to dig why it's failing at this point !
825
+ # print('VGS')
826
+ # print(gs)
827
+ # print(ge)
828
+
829
+ if mng.tiling_mode == TilingMode.ShiftBoundary:
830
+ for dim in range(len(vgs)):
831
+ if ps[dim] == 0:
832
+ vgs[dim] = 0
833
+ if pe[dim] == mng.data_shape[dim]:
834
+ vge[dim] = mng.data_shape[dim]
835
+
836
+ # relative start, relative end. This will be used on pred_tiled
837
+ rs = vgs - ps
838
+ re = rs + (vge - vgs)
839
+ # print('RS')
840
+ # print(rs)
841
+ # print(re)
842
+
843
+ # print(output.shape)
844
+ # print(predictions.shape)
845
+ for ch_idx in range(predictions.shape[1]):
846
+ if len(output.shape) == 4:
847
+ # channel dimension is the last one.
848
+ output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
849
+ predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
850
+ )
851
+ elif len(output.shape) == 5:
852
+ # channel dimension is the last one.
853
+ assert vge[0] - vgs[0] == 1, "Only one frame is supported"
854
+ output[
855
+ vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
856
+ ] = predictions[dset_idx][
857
+ ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
858
+ ]
859
+ else:
860
+ raise ValueError(f"Unsupported shape {output.shape}")
947
861
 
948
- ax.set_xlabel("RMV")
949
- ax.set_ylabel("RMSE")
950
- ax.legend()
862
+ return output
@@ -1,7 +1,6 @@
1
1
  """Functions used to create a README.md file for BMZ export."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Optional
5
4
 
6
5
  import yaml
7
6
 
@@ -28,7 +27,7 @@ def _yaml_block(yaml_str: str) -> str:
28
27
  def readme_factory(
29
28
  config: Configuration,
30
29
  careamics_version: str,
31
- data_description: Optional[str] = None,
30
+ data_description: str,
32
31
  ) -> Path:
33
32
  """Create a README file for the model.
34
33
 
@@ -41,18 +40,14 @@ def readme_factory(
41
40
  CAREamics configuration.
42
41
  careamics_version : str
43
42
  CAREamics version.
44
- data_description : Optional[str], optional
45
- Description of the data, by default None.
43
+ data_description : str
44
+ Description of the data.
46
45
 
47
46
  Returns
48
47
  -------
49
48
  Path
50
49
  Path to the README file.
51
50
  """
52
- algorithm = config.algorithm_config
53
- training = config.training_config
54
- data = config.data_config
55
-
56
51
  # create file
57
52
  # TODO use tempfile as in the bmz_io module
58
53
  with cwd(get_careamics_home()):
@@ -65,42 +60,39 @@ def readme_factory(
65
60
 
66
61
  description = [f"# {algorithm_pretty_name}\n\n"]
67
62
 
63
+ # data description
64
+ description.append("## Data description\n\n")
65
+ description.append(data_description)
66
+ description.append("\n\n")
67
+
68
68
  # algorithm description
69
- description.append("Algorithm description:\n\n")
69
+ description.append("## Algorithm description:\n\n")
70
70
  description.append(config.get_algorithm_description())
71
71
  description.append("\n\n")
72
72
 
73
- # algorithm details
73
+ # configuration description
74
+ description.append("## Configuration\n\n")
75
+
74
76
  description.append(
75
77
  f"{algorithm_flavour} was trained using CAREamics (version "
76
- f"{careamics_version}) with the following algorithm "
77
- f"parameters:\n\n"
78
- )
79
- description.append(
80
- _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
78
+ f"{careamics_version}) using the following configuration:\n\n"
81
79
  )
82
- description.append("\n\n")
83
-
84
- # data description
85
- description.append("## Data description\n\n")
86
- if data_description is not None:
87
- description.append(data_description)
88
- description.append("\n\n")
89
-
90
- description.append("The data was processed using the following parameters:\n\n")
91
80
 
92
- description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True))))
81
+ description.append(_yaml_block(yaml.dump(config.model_dump(exclude_none=True))))
93
82
  description.append("\n\n")
94
83
 
95
- # training description
96
- description.append("## Training description\n\n")
97
-
98
- description.append("The model was trained using the following parameters:\n\n")
84
+ # validation
85
+ description.append("# Validation\n\n")
99
86
 
100
87
  description.append(
101
- _yaml_block(yaml.dump(training.model_dump(exclude_none=True)))
88
+ "In order to validate the model, we encourage users to acquire a "
89
+ "test dataset with ground-truth data. Comparing the ground-truth data "
90
+ "with the prediction allows unbiased evaluation of the model performances. "
91
+ "This can be done for instance by using metrics such as PSNR, SSIM, or"
92
+ "MicroSSIM. In the absence of ground-truth, inspecting the residual image "
93
+ "(difference between input and predicted image) can be helpful to identify "
94
+ "whether real signal is removed from the input image.\n\n"
102
95
  )
103
- description.append("\n\n")
104
96
 
105
97
  # references
106
98
  reference = config.get_algorithm_references()
@@ -111,9 +103,9 @@ def readme_factory(
111
103
 
112
104
  # links
113
105
  description.append(
114
- "## Links\n\n"
106
+ "# Links\n\n"
115
107
  "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
116
- "- [CAREamics documentation](https://careamics.github.io/latest/)\n"
108
+ "- [CAREamics documentation](https://careamics.github.io/)\n"
117
109
  )
118
110
 
119
111
  readme.write_text("".join(description))
@@ -0,0 +1,171 @@
1
+ """Convenience function to create covers for the BMZ."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+ from PIL import Image
8
+
9
+ color_palette = np.array(
10
+ [
11
+ np.array([255, 195, 0]), # grey
12
+ np.array([189, 226, 240]),
13
+ np.array([96, 60, 76]),
14
+ np.array([193, 225, 193]),
15
+ ]
16
+ )
17
+
18
+
19
+ def _get_norm_slice(array: NDArray) -> NDArray:
20
+ """Get the normalized middle slice of a 4D or 5D array (SC(Z)YX).
21
+
22
+ Parameters
23
+ ----------
24
+ array : NDArray
25
+ Array from which to get the middle slice.
26
+
27
+ Returns
28
+ -------
29
+ NDArray
30
+ Normalized middle slice of the input array.
31
+ """
32
+ if array.ndim not in (4, 5):
33
+ raise ValueError("Array must be 4D or 5D.")
34
+
35
+ channels = array.shape[1] > 1
36
+ z_stack = array.ndim == 5
37
+
38
+ # get slice
39
+ if z_stack:
40
+ array_slice = array[0, :, array.shape[2] // 2, ...]
41
+ else:
42
+ array_slice = array[0, ...]
43
+
44
+ # channels
45
+ if channels:
46
+ array_slice = np.moveaxis(array_slice, 0, -1)
47
+ else:
48
+ array_slice = array_slice[0, ...]
49
+
50
+ # normalize
51
+ array_slice = (
52
+ 255
53
+ * (array_slice - array_slice.min())
54
+ / (array_slice.max() - array_slice.min())
55
+ )
56
+
57
+ return array_slice.astype(np.uint8)
58
+
59
+
60
+ def _four_channel_image(array: NDArray) -> Image:
61
+ """Convert 4-channel array to Image.
62
+
63
+ Parameters
64
+ ----------
65
+ array : NDArray
66
+ Normalized array to convert.
67
+
68
+ Returns
69
+ -------
70
+ Image
71
+ Converted array.
72
+ """
73
+ colors = color_palette[np.newaxis, np.newaxis, :, :]
74
+ four_c_array = np.sum(array[..., :4, np.newaxis] * colors, axis=-2).astype(np.uint8)
75
+
76
+ return Image.fromarray(four_c_array).convert("RGB")
77
+
78
+
79
+ def _convert_to_image(original_shape: tuple[int, ...], array: NDArray) -> Image:
80
+ """Convert to Image.
81
+
82
+ Parameters
83
+ ----------
84
+ original_shape : tuple
85
+ Original shape of the array.
86
+ array : NDArray
87
+ Normalized array to convert.
88
+
89
+ Returns
90
+ -------
91
+ Image
92
+ Converted array.
93
+ """
94
+ n_channels = original_shape[1]
95
+
96
+ if n_channels > 1:
97
+ if n_channels == 3:
98
+ return Image.fromarray(array).convert("RGB")
99
+ elif n_channels == 2:
100
+ # add an empty channel to the numpy array
101
+ array = np.concatenate([np.zeros_like(array[..., 0:1]), array], axis=-1)
102
+
103
+ return Image.fromarray(array).convert("RGB")
104
+ else: # more than 4
105
+ return _four_channel_image(array[..., :4])
106
+ else:
107
+ return Image.fromarray(array).convert("L").convert("RGB")
108
+
109
+
110
+ def create_cover(directory: Path, array_in: NDArray, array_out: NDArray) -> Path:
111
+ """Create a cover image from input and output arrays.
112
+
113
+ Input and output arrays are expected to be SC(Z)YX. For images with a Z
114
+ dimension, the middle slice is taken.
115
+
116
+ Parameters
117
+ ----------
118
+ directory : Path
119
+ Directory in which to save the cover.
120
+ array_in : numpy.ndarray
121
+ Array from which to create the cover image.
122
+ array_out : numpy.ndarray
123
+ Array from which to create the cover image.
124
+
125
+ Returns
126
+ -------
127
+ Path
128
+ Path to the saved cover image.
129
+ """
130
+ # extract slice and normalize arrays
131
+ slice_in = _get_norm_slice(array_in)
132
+ slice_out = _get_norm_slice(array_out)
133
+
134
+ horizontal_split = slice_in.shape[-1] == slice_out.shape[-1]
135
+ if not horizontal_split:
136
+ if slice_in.shape[-2] != slice_out.shape[-2]:
137
+ raise ValueError("Input and output arrays have different shapes.")
138
+
139
+ # convert to Image
140
+ image_in = _convert_to_image(array_in.shape, slice_in)
141
+ image_out = _convert_to_image(array_out.shape, slice_out)
142
+
143
+ # split horizontally or vertically
144
+ if horizontal_split:
145
+ width = image_in.width // 2
146
+
147
+ cover = Image.new("RGB", (image_in.width, image_in.height))
148
+ cover.paste(image_in.crop((0, 0, width, image_in.height)), (0, 0))
149
+ cover.paste(
150
+ image_out.crop(
151
+ (image_in.width - width, 0, image_in.width, image_in.height)
152
+ ),
153
+ (width, 0),
154
+ )
155
+ else:
156
+ height = image_in.height // 2
157
+
158
+ cover = Image.new("RGB", (image_in.width, image_in.height))
159
+ cover.paste(image_in.crop((0, 0, image_in.width, height)), (0, 0))
160
+ cover.paste(
161
+ image_out.crop(
162
+ (0, image_in.height - height, image_in.width, image_in.height)
163
+ ),
164
+ (0, height),
165
+ )
166
+
167
+ # save
168
+ cover_path = directory / "cover.png"
169
+ cover.save(cover_path)
170
+
171
+ return cover_path