careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0

--- /dev/null
+++ b/careamics/config/likelihood_model.py
@@ -0,0 +1,43 @@
+"""Likelihood model."""
+
+from typing import Literal, Optional, Union
+
+import torch
+from pydantic import BaseModel, ConfigDict
+
+from careamics.models.lvae.noise_models import (
+    GaussianMixtureNoiseModel,
+    MultiChannelNoiseModel,
+)
+
+NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+class GaussianLikelihoodConfig(BaseModel):
+    """Gaussian likelihood configuration."""
+
+    model_config = ConfigDict(validate_assignment=True)
+
+    predict_logvar: Optional[Literal["pixelwise"]] = None
+    """If `pixelwise`, log-variance is computed for each pixel, else log-variance
+    is not computed."""
+
+    logvar_lowerbound: Union[float, None] = None
+    """The lowerbound value for log-variance."""
+
+
+class NMLikelihoodConfig(BaseModel):
+    """Noise model likelihood configuration."""
+
+    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+    data_mean: Union[torch.Tensor] = torch.zeros(1)
+    """The mean of the data, used to unnormalize data for noise model evaluation.
+    Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+    data_std: Union[torch.Tensor] = torch.ones(1)
+    """The standard deviation of the data, used to unnormalize data for noise
+    model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+    noise_model: Union[NoiseModel, None] = None
+    """The noise model instance used to compute the likelihood."""
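
For orientation, the snippet below constructs both new config models directly. This is a minimal sketch based only on the fields visible in this diff, assuming careamics 0.0.3 is installed.

    import torch

    from careamics.config.likelihood_model import (
        GaussianLikelihoodConfig,
        NMLikelihoodConfig,
    )

    # Pixelwise log-variance prediction, with the log-variance clamped from below.
    gaussian_config = GaussianLikelihoodConfig(
        predict_logvar="pixelwise",
        logvar_lowerbound=-5.0,
    )

    # Noise-model likelihood with statistics for a single target channel;
    # no noise model is attached yet (`noise_model` defaults to None).
    nm_config = NMLikelihoodConfig(
        data_mean=torch.zeros(1),
        data_std=torch.ones(1),
    )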

--- /dev/null
+++ b/careamics/config/nm_model.py
@@ -0,0 +1,101 @@
+"""Noise models config."""
+
+from pathlib import Path
+from typing import Literal, Optional, Union
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
+
+# TODO: add histogram-based noise model
+
+
+class GaussianMixtureNMConfig(BaseModel):
+    """Gaussian mixture noise model."""
+
+    model_config = ConfigDict(
+        protected_namespaces=(),
+        validate_assignment=True,
+        arbitrary_types_allowed=True,
+        extra="allow",
+    )
+    # model type
+    model_type: Literal["GaussianMixtureNoiseModel"]
+
+    path: Optional[Union[Path, str]] = None
+    """Path to the directory where the trained noise model (*.npz) is saved in the
+    `train` method."""
+
+    signal: Optional[Union[str, Path, np.ndarray]] = None
+    """Path to the file containing signal or respective numpy array."""
+
+    observation: Optional[Union[str, Path, np.ndarray]] = None
+    """Path to the file containing observation or respective numpy array."""
+
+    weight: Optional[np.ndarray] = None
+    """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
+    describing the GMM noise model, with each row corresponding to one
+    parameter of each gaussian, namely [mean, standard deviation and weight].
+    Specifically, rows are organized as follows:
+    - first n_gaussian rows correspond to the means
+    - next n_gaussian rows correspond to the weights
+    - last n_gaussian rows correspond to the standard deviations
+    If `weight=None`, the weight array is initialized using the `min_signal`
+    and `max_signal` parameters."""
+
+    n_gaussian: int = Field(default=1, ge=1)
+    """Number of gaussians used for the GMM."""
+
+    n_coeff: int = Field(default=2, ge=2)
+    """Number of coefficients to describe the functional relationship between gaussian
+    parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
+    relationship and so on."""
+
+    min_signal: float = Field(default=0.0, ge=0.0)
+    """Minimum signal intensity expected in the image."""
+
+    max_signal: float = Field(default=1.0, ge=0.0)
+    """Maximum signal intensity expected in the image."""
+
+    min_sigma: float = Field(default=200.0, ge=0.0)  # TODO took from nb in pn2v
+    """Minimum value of `standard deviation` allowed in the GMM.
+    All values of `standard deviation` below this are clamped to this value."""
+
+    tol: float = Field(default=1e-10)
+    """Tolerance used in the computation of the noise model likelihood."""
+
+    @model_validator(mode="after")
+    def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
+        """Validate paths provided in the config.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+        """
+        if self.path and (self.signal is not None or self.observation is not None):
+            raise ValueError(
+                "Either only 'path' to pre-trained noise model should be"
+                "provided or only signal and observation in form of paths"
+                "or numpy arrays."
+            )
+        if not self.path and (self.signal is None or self.observation is None):
+            raise ValueError(
+                "Either only 'path' to pre-trained noise model should be"
+                "provided or only signal and observation in form of paths"
+                "or numpy arrays."
+            )
+        return self
+
+
+# The noise model is given by a set of GMMs, one for each target
+# e.g., 2 target channels, 2 noise models
+class MultiChannelNMConfig(BaseModel):
+    """Noise Model config aggregating noise models for single output channels."""
+
+    # TODO: check that this model config is OK
+    model_config = ConfigDict(
+        validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
+    )
+    noise_models: list[GaussianMixtureNMConfig]
+    """List of noise models, one for each target channel."""
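
The `model_validator` above enforces that either a pre-trained `path` or the `signal`/`observation` training data is provided, but not both (and not neither). A minimal sketch of both outcomes, assuming careamics 0.0.3 is installed; pydantic surfaces the validator's ValueError as a ValidationError.

    import numpy as np
    from pydantic import ValidationError

    from careamics.config.nm_model import GaussianMixtureNMConfig

    # Valid: training data is given and no pre-trained path.
    nm_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        signal=np.zeros((64, 64)),
        observation=np.zeros((64, 64)),
        n_gaussian=3,
    )

    # Invalid: neither `path` nor `signal`/`observation` is provided.
    try:
        GaussianMixtureNMConfig(model_type="GaussianMixtureNoiseModel")
    except ValidationError as err:
        print(err)  # wraps the ValueError raised by the validator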

--- a/careamics/config/support/supported_algorithms.py
+++ b/careamics/config/support/supported_algorithms.py
@@ -6,15 +6,28 @@ from careamics.utils import BaseEnum
 
 
 class SupportedAlgorithm(str, BaseEnum):
-    """Algorithms available in CAREamics.
-
-    # TODO
-    """
+    """Algorithms available in CAREamics."""
 
     N2V = "n2v"
+    """Noise2Void algorithm, a self-supervised approach based on blind denoising."""
+
     CARE = "care"
+    """Content-aware image restoration, a supervised algorithm used for a variety
+    of tasks."""
+
     N2N = "n2n"
+    """Noise2Noise algorithm, a self-supervised denoising scheme based on comparing
+    noisy images of the same sample."""
+
+    MUSPLIT = "musplit"
+    """An image splitting approach based on ladder VAE architectures."""
+
+    DENOISPLIT = "denoisplit"
+    """An image splitting and denoising approach based on ladder VAE architectures."""
+
     CUSTOM = "custom"
+    """Custom algorithm, used for cases where a custom architecture is provided."""
+
     # PN2V = "pn2v"
     # HDN = "hdn"
     # SEG = "segmentation"
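
Because `SupportedAlgorithm` mixes in `str`, the new members compare equal to their plain-string values, which is how the configuration models pass algorithm names around. A small illustration:

    from careamics.config.support import SupportedAlgorithm

    assert SupportedAlgorithm.MUSPLIT == "musplit"
    assert SupportedAlgorithm("denoisplit") is SupportedAlgorithm.DENOISPLIT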

--- a/careamics/config/support/supported_architectures.py
+++ b/careamics/config/support/supported_architectures.py
@@ -4,17 +4,14 @@ from careamics.utils import BaseEnum
 
 
 class SupportedArchitecture(str, BaseEnum):
-    """Supported architectures.
+    """Supported architectures."""
 
-
+    UNET = "UNet"
+    """UNet architecture used with N2V, CARE and Noise2Noise."""
 
-
-
-    - Custom: custom model registered with `@register_model` decorator
-    """
+    LVAE = "LVAE"
+    """Ladder Variational Autoencoder used for muSplit and denoiSplit."""
 
-    UNET = "UNet"
-    VAE = "VAE"
-    CUSTOM = (
-        "Custom"  # TODO all the others tags are small letters, except the architect
-    )
+    CUSTOM = "custom"
+    """Keyword used for custom architectures provided by users and only compatible
+    with `FCNAlgorithmConfig` configuration."""

--- a/careamics/config/support/supported_losses.py
+++ b/careamics/config/support/supported_losses.py
@@ -22,6 +22,8 @@ class SupportedLoss(str, BaseEnum):
     N2V = "n2v"
     # PN2V = "pn2v"
     # HDN = "hdn"
+    MUSPLIT = "musplit"
+    DENOISPLIT = "denoisplit"
+    DENOISPLIT_MUSPLIT = "denoisplit_musplit"
     # CE = "ce"
     # DICE = "dice"
-    # CUSTOM = "custom"  # TODO create mechanism for that
|
@@ -33,7 +33,7 @@ class N2VManipulateModel(TransformModel):
|
|
|
33
33
|
|
|
34
34
|
name: Literal["N2VManipulate"] = "N2VManipulate"
|
|
35
35
|
roi_size: int = Field(default=11, ge=3, le=21)
|
|
36
|
-
masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=
|
|
36
|
+
masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
|
|
37
37
|
strategy: Literal["uniform", "median"] = Field(default="uniform")
|
|
38
38
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
|
|
39
39
|
struct_mask_span: int = Field(default=5, ge=3, le=15)
|
|
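
The single change here widens the upper bound of `masked_pixel_percentage` to 10.0. A short sketch of the resulting validation behavior, with the import path assumed from the file location above:

    from pydantic import ValidationError

    from careamics.config.transformations.n2v_manipulate_model import N2VManipulateModel

    N2VManipulateModel(masked_pixel_percentage=5.0)  # accepted under the new le=10.0 bound

    try:
        N2VManipulateModel(masked_pixel_percentage=15.0)  # above the bound
    except ValidationError as err:
        print(err)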

--- /dev/null
+++ b/careamics/config/vae_algorithm_model.py
@@ -0,0 +1,171 @@
+"""Algorithm configuration."""
+
+from __future__ import annotations
+
+from pprint import pformat
+from typing import Literal, Optional, Union
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
+
+from careamics.config.support import SupportedAlgorithm, SupportedLoss
+
+from .architectures import CustomModel, LVAEModel
+from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
+from .nm_model import MultiChannelNMConfig
+from .optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+class VAEAlgorithmConfig(BaseModel):
+    """Algorithm configuration.
+
+    This Pydantic model validates the parameters governing the components of the
+    training algorithm: which algorithm, loss function, model architecture, optimizer,
+    and learning rate scheduler to use.
+
+    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
+    only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
+    allows you to register your own architecture and select it using its name as
+    `name` in the custom pydantic model.
+
+    Attributes
+    ----------
+    algorithm : algorithm: Literal["musplit", "denoisplit", "custom"]
+        Algorithm to use.
+    loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+        Loss function to use.
+    model : Union[LVAEModel, CustomModel]
+        Model architecture to use.
+    noise_model: Optional[MultiChannelNmModel]
+        Noise model to use.
+    noise_model_likelihood_model: Optional[NMLikelihoodModel]
+        Noise model likelihood model to use.
+    gaussian_likelihood_model: Optional[GaussianLikelihoodModel]
+        Gaussian likelihood model to use.
+    optimizer : OptimizerModel, optional
+        Optimizer to use.
+    lr_scheduler : LrSchedulerModel, optional
+        Learning rate scheduler to use.
+
+    Raises
+    ------
+    ValueError
+        Algorithm parameter type validation errors.
+    ValueError
+        If the algorithm, loss and model are not compatible.
+
+    Examples
+    --------
+    # TODO add once finalized
+    """
+
+    # Pydantic class configuration
+    model_config = ConfigDict(
+        protected_namespaces=(),  # allows to use model_* as a field name
+        validate_assignment=True,
+        extra="allow",
+    )
+
+    # Mandatory fields
+    # defined in SupportedAlgorithm
+    # TODO: Use supported Enum classes for typing?
+    # - values can still be passed as strings and they will be cast to Enum
+    algorithm_type: Literal["vae"]
+    algorithm: Literal["musplit", "denoisplit", "custom"]
+    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
+
+    # TODO: these are configs, change naming of attrs
+    noise_model: Optional[MultiChannelNMConfig] = None
+    noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
+    gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
+
+    # Optional fields
+    optimizer: OptimizerModel = OptimizerModel()
+    """Optimizer to use, defined in SupportedOptimizer."""
+
+    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+
+    @model_validator(mode="after")
+    def algorithm_cross_validation(self: Self) -> Self:
+        """Validate the algorithm model based on `algorithm`.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        # musplit
+        if self.algorithm == SupportedAlgorithm.MUSPLIT:
+            if self.loss != SupportedLoss.MUSPLIT:
+                raise ValueError(
+                    f"Algorithm {self.algorithm} only supports loss `musplit`."
+                )
+
+        if self.algorithm == SupportedAlgorithm.DENOISPLIT:
+            if self.loss not in [
+                SupportedLoss.DENOISPLIT,
+                SupportedLoss.DENOISPLIT_MUSPLIT,
+            ]:
+                raise ValueError(
+                    f"Algorithm {self.algorithm} only supports loss `denoisplit` "
+                    "or `denoisplit_musplit."
+                )
+            if (
+                self.loss == SupportedLoss.DENOISPLIT
+                and self.model.predict_logvar is not None
+            ):
+                raise ValueError(
+                    "Algorithm `denoisplit` with loss `denoisplit` only supports "
+                    "`predict_logvar` as `None`."
+                )
+            if self.noise_model is None:
+                raise ValueError("Algorithm `denoisplit` requires a noise model.")
+        # TODO: what if algorithm is not musplit or denoisplit (HDN?)
+        return self
+
+    @model_validator(mode="after")
+    def output_channels_validation(self: Self) -> Self:
+        """Validate the consistency between number of out channels and noise models.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.noise_model is not None:
+            assert self.model.output_channels == len(self.noise_model.noise_models), (
+                f"Number of output channels ({self.model.output_channels}) must match "
+                f"the number of noise models ({len(self.noise_model.noise_models)})."
+            )
+        return self
+
+    @model_validator(mode="after")
+    def predict_logvar_validation(self: Self) -> Self:
+        """Validate the consistency of `predict_logvar` throughout the model.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.gaussian_likelihood_model is not None:
+            assert (
+                self.model.predict_logvar
+                == self.gaussian_likelihood_model.predict_logvar
+            ), (
+                f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
+                "Gaussian likelihood model `predict_logvar` "
+                f"({self.gaussian_likelihood_model.predict_logvar}).",
+            )
+        return self
+
+    def __str__(self) -> str:
+        """Pretty string representing the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
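
The docstring's Examples section is still a TODO, so the following is only a sketch of how the pieces fit together. The `LVAEModel` keyword arguments are assumptions: this diff confirms the model carries `architecture` (the `Field` discriminator), `predict_logvar` and `output_channels`, but `lvae_model.py` itself is not shown here.

    from careamics.config.architectures import LVAEModel
    from careamics.config.vae_algorithm_model import VAEAlgorithmConfig

    # Hypothetical LVAE settings: only `architecture`, `predict_logvar` and
    # `output_channels` are referenced by the validators in this diff.
    model = LVAEModel(architecture="LVAE", predict_logvar="pixelwise", output_channels=2)

    config = VAEAlgorithmConfig(
        algorithm_type="vae",
        algorithm="musplit",
        loss="musplit",  # the only loss `algorithm_cross_validation` accepts for musplit
        model=model,
    )
    print(config)  # __str__ pretty-prints the dumped configuration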

--- /dev/null
+++ b/careamics/dataset/tiling/lvae_tiled_patching.py
@@ -0,0 +1,282 @@
+"""Functions to reimplement the tiling in the Disentangle repository."""
+
+import builtins
+import itertools
+from typing import Any, Generator, Optional, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from careamics.config.tile_information import TileInformation
+
+
+def extract_tiles(
+    arr: NDArray,
+    tile_size: NDArray[np.int_],
+    overlaps: NDArray[np.int_],
+    padding_kwargs: Optional[dict[str, Any]] = None,
+) -> Generator[tuple[NDArray, TileInformation], None, None]:
+    """Generate tiles from the input array with specified overlap.
+
+    The tiles cover the whole array; which will be additionally padded, to ensure that
+    the section of the tile that contributes to the final image comes from the center
+    of the tile.
+
+    The method returns a generator that yields tuples of array and tile information,
+    the latter includes whether the tile is the last one, the coordinates of the
+    overlap crop, and the coordinates of the stitched tile.
+
+    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
+    where C can be a singleton.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Array of shape (S, C, (Z), Y, X).
+    tile_size : 1D numpy.ndarray of tuple
+        Tile sizes in each dimension, of length 2 or 3.
+    overlaps : 1D numpy.ndarray of tuple
+        Overlap values in each dimension, of length 2 or 3.
+    padding_kwargs : dict, optional
+        The arguments of `np.pad` after the first two arguments, `array` and
+        `pad_width`. If not specified the default will be `{"mode": "reflect"}`. See
+        `numpy.pad` docs:
+        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.
+
+    Yields
+    ------
+    Generator[Tuple[np.ndarray, TileInformation], None, None]
+        Tile generator, yields the tile and additional information.
+    """
+    if padding_kwargs is None:
+        padding_kwargs = {"mode": "reflect"}
+
+    # Iterate over num samples (S)
+    for sample_idx in range(arr.shape[0]):
+        sample = arr[sample_idx, ...]
+        data_shape = np.array(sample.shape)
+
+        # add padding to ensure evenly spaced & overlapping tiles.
+        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
+        padding = ((0, 0), *spatial_padding)
+        sample = np.pad(sample, padding, **padding_kwargs)
+
+        # The number of tiles in each dimension, should be of length 2 or 3
+        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)
+        # itertools.product is equivalent of nested loops
+
+        stitch_size = tile_size - overlaps
+        for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]):
+
+            # calculate crop coordinates
+            crop_coords_start = np.array(tile_grid_coords) * stitch_size
+            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
+                ...,
+                *[
+                    slice(coords, coords + extent)
+                    for coords, extent in zip(crop_coords_start, tile_size)
+                ],
+            )
+            tile = sample[crop_slices]
+
+            tile_info = compute_tile_info(
+                np.array(tile_grid_coords),
+                np.array(data_shape),
+                np.array(tile_size),
+                np.array(overlaps),
+                sample_idx,
+            )
+            # TODO: kinda weird this is a generator,
+            # -> doesn't really save memory ? Don't think there are any places the
+            # tiles are not exracted all at the same time.
+            # Although I guess it would make sense for a zarr tile extractor.
+            yield tile, tile_info
+
+
+def compute_tile_info(
+    tile_grid_coords: NDArray[np.int_],
+    data_shape: NDArray[np.int_],
+    tile_size: NDArray[np.int_],
+    overlaps: NDArray[np.int_],
+    sample_id: int = 0,
+) -> TileInformation:
+    """
+    Compute the tile information for a tile with the coordinates `tile_grid_coords`.
+
+    Parameters
+    ----------
+    tile_grid_coords : 1D np.array of int
+        The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
+        tiling the coordinates for the second tile in the first row of tiles would be
+        (0, 1).
+    data_shape : 1D np.array of int
+        The shape of the data, should be (C, (Z), Y, X) where Z is optional.
+    tile_size : 1D np.array of int
+        Tile sizes in each dimension, of length 2 or 3.
+    overlaps : 1D np.array of int
+        Overlap values in each dimension, of length 2 or 3.
+    sample_id : int, default=0
+        An ID to identify which sample a tile belongs to.
+
+    Returns
+    -------
+    TileInformation
+        Information that describes how to crop and stitch a tile to create a full image.
+    """
+    spatial_dims_shape = data_shape[-len(tile_size) :]
+
+    # The extent of the tile which will make up part of the stitched image.
+    stitch_size = tile_size - overlaps
+    stitch_coords_start = tile_grid_coords * stitch_size
+    stitch_coords_end = stitch_coords_start + stitch_size
+
+    tile_coords_start = stitch_coords_start - overlaps // 2
+
+    # --- replace out of bounds indices
+    out_of_lower_bound = stitch_coords_start < 0
+    out_of_upper_bound = stitch_coords_end > spatial_dims_shape
+    stitch_coords_start[out_of_lower_bound] = 0
+    stitch_coords_end[out_of_upper_bound] = spatial_dims_shape[out_of_upper_bound]
+
+    # --- calculate overlap crop coords
+    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
+    overlap_crop_coords_end = overlap_crop_coords_start + (
+        stitch_coords_end - stitch_coords_start
+    )
+
+    # --- combine start and end
+    stitch_coords = tuple(
+        (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+    )
+    overlap_crop_coords = tuple(
+        (start, end)
+        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+    )
+
+    # --- Check if last tile
+    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
+    last_tile = (tile_grid_coords == (tile_grid_shape - 1)).all()
+
+    tile_info = TileInformation(
+        array_shape=data_shape,
+        last_tile=last_tile,
+        overlap_crop_coords=overlap_crop_coords,
+        stitch_coords=stitch_coords,
+        sample_id=sample_id,
+    )
+    return tile_info
+
+
+def compute_padding(
+    data_shape: NDArray[np.int_],
+    tile_size: NDArray[np.int_],
+    overlaps: NDArray[np.int_],
+) -> tuple[tuple[int, int], ...]:
+    """
+    Calculate padding to ensure stitched data comes from the center of a tile.
+
+    Padding is added to an array with shape `data_shape` so that when tiles are
+    stitched together, the data used always comes from the center of a tile, even for
+    tiles at the boundaries of the array.
+
+    Parameters
+    ----------
+    data_shape : 1D numpy.array of int
+        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+    tile_size : 1D numpy.array of int
+        The tile size in each dimension, ((Z), Y, X).
+    overlaps : 1D numpy.array of int
+        The tile overlap in each dimension, ((Z), Y, X).
+
+    Returns
+    -------
+    tuple of (int, int)
+        A tuple specifying the padding to add in each dimension, each element is a two
+        element tuple specifying the padding to add before and after the data. This
+        can be used as the `pad_width` argument to `numpy.pad`.
+    """
+    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
+    covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps
+
+    pad_before = overlaps // 2
+    pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before
+
+    return tuple((before, after) for before, after in zip(pad_before, pad_after))
+
+
+def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
+    """Calculate the number of tiles in a specific dimension.
+
+    Parameters
+    ----------
+    axis_size : int
+        The length of the data for in a specific dimension.
+    tile_size : int
+        The length of the tiles in a specific dimension.
+    overlap : int
+        The tile overlap in a specific dimension.
+
+    Returns
+    -------
+    int
+        The number of tiles that fit in one dimension given the arguments.
+    """
+    return int(np.ceil(axis_size / (tile_size - overlap)))
+
+
+def total_n_tiles(
+    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
+) -> int:
+    """Calculate The total number of tiles over all dimensions.
+
+    Parameters
+    ----------
+    data_shape : 1D numpy.array of int
+        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+    tile_size : 1D numpy.array of int
+        The tile size in each dimension, ((Z), Y, X).
+    overlaps : 1D numpy.array of int
+        The tile overlap in each dimension, ((Z), Y, X).
+
+
+    Returns
+    -------
+    int
+        The total number of tiles over all dimensions.
+    """
+    result = 1
+    # assume spatial dimension are the last dimensions so iterate backwards
+    for i in range(-1, -len(tile_size) - 1, -1):
+        result = result * n_tiles_1d(data_shape[i], tile_size[i], overlaps[i])
+
+    return result
+
+
+def compute_tile_grid_shape(
+    data_shape: NDArray[np.int_],
+    tile_size: NDArray[np.int_],
+    overlaps: NDArray[np.int_],
+) -> tuple[int, ...]:
+    """Calculate the number of tiles in each dimension.
+
+    This can be thought of as a grid of tiles.
+
+    Parameters
+    ----------
+    data_shape : 1D numpy.array of int
+        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+    tile_size : 1D numpy.array of int
+        The tile size in each dimension, ((Z), Y, X).
+    overlaps : 1D numpy.array of int
+        The tile overlap in each dimension, ((Z), Y, X).
+
+    Returns
+    -------
+    tuple of int
+        The number of tiles in each direction, ((Z, Y, X)).
+    """
+    shape = [0 for _ in range(len(tile_size))]
+    # assume spatial dimension are the last dimensions so iterate backwards
+    for i in range(-1, -len(tile_size) - 1, -1):
+        shape[i] = n_tiles_1d(data_shape[i], tile_size[i], overlaps[i])
+    return tuple(shape)
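
To make the tiling arithmetic concrete, here is a worked 2D example using the helpers above: `n_tiles_1d(64, 32, 8) = ceil(64 / 24) = 3` tiles per axis, the covered extent is `24 * 3 + 8 = 80`, so each spatial axis is padded by `(8 // 2, 80 - 64 - 4) = (4, 12)`.

    import numpy as np

    from careamics.dataset.tiling.lvae_tiled_patching import (
        compute_padding,
        compute_tile_grid_shape,
        extract_tiles,
    )

    data_shape = np.array([1, 64, 64])  # (C, Y, X)
    tile_size = np.array([32, 32])
    overlaps = np.array([8, 8])

    print(compute_tile_grid_shape(data_shape, tile_size, overlaps))  # (3, 3)
    print(compute_padding(data_shape, tile_size, overlaps))  # padding of (4, 12) per axis

    # extract_tiles expects a leading sample axis: (S, C, Y, X)
    arr = np.zeros((1, 1, 64, 64))
    tiles = list(extract_tiles(arr, tile_size, overlaps))
    print(len(tiles))  # 9 tiles, each of shape (1, 32, 32)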