careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl
This diff shows the contents of two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
- careamics/careamist.py +25 -17
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/architectures/lvae_model.py +0 -4
- careamics/config/configuration_factory.py +480 -177
- careamics/config/configuration_model.py +1 -2
- careamics/config/data_model.py +1 -15
- careamics/config/fcn_algorithm_model.py +14 -9
- careamics/config/likelihood_model.py +21 -4
- careamics/config/nm_model.py +31 -5
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +2 -36
- careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
- careamics/lightning/lightning_module.py +10 -8
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/loss_factory.py +3 -3
- careamics/losses/lvae/losses.py +2 -2
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
- careamics/lvae_training/dataset/lc_dataset.py +28 -20
- careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +4 -2
- careamics/model_io/bmz_io.py +6 -5
- careamics/models/lvae/likelihoods.py +18 -9
- careamics/models/lvae/lvae.py +12 -16
- careamics/models/lvae/noise_models.py +1 -1
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +204 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
- careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
- careamics/lvae_training/dataset/data_utils.py +0 -701
- careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_model.py
CHANGED

@@ -124,7 +124,6 @@ class Configuration(BaseModel):
     >>> config_dict = {
     ...     "experiment_name": "N2V_experiment",
     ...     "algorithm_config": {
-    ...         "algorithm_type": "fcn",
     ...         "algorithm": "n2v",
     ...         "loss": "n2v",
     ...         "model": {
@@ -158,7 +157,7 @@ class Configuration(BaseModel):

    # Sub-configurations
    algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
-        discriminator="algorithm_type"
+        discriminator="algorithm"
    )
    """Algorithm configuration, holding all parameters required to configure the
    model."""
careamics/config/data_model.py
CHANGED
@@ -10,7 +10,6 @@ from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
-    Discriminator,
     Field,
     PlainSerializer,
     field_validator,
@@ -19,9 +18,7 @@ from pydantic import (
 from typing_extensions import Annotated, Self

 from .support import SupportedTransform
-from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.xy_flip_model import XYFlipModel
-from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
+from .transformations import TRANSFORMS_UNION, N2VManipulateModel
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2


@@ -48,17 +45,6 @@ Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
 """Annotated float type, used to serialize floats to strings."""


-TRANSFORMS_UNION = Annotated[
-    Union[
-        XYFlipModel,
-        XYRandomRotate90Model,
-        N2VManipulateModel,
-    ],
-    Discriminator("name"),  # used to tell the different transform models apart
-]
-"""Available transforms in CAREamics."""
-
-
 class DataConfig(BaseModel):
     """
     Data configuration.
careamics/config/fcn_algorithm_model.py
CHANGED

@@ -24,11 +24,11 @@ class FCNAlgorithmConfig(BaseModel):

    Attributes
    ----------
-    algorithm :
+    algorithm : {"n2v", "care", "n2n", "custom"}
        Algorithm to use.
-    loss :
+    loss : {"n2v", "mae", "mse"}
        Loss function to use.
-    model :
+    model : UNetModel or CustomModel
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
@@ -48,7 +48,6 @@ class FCNAlgorithmConfig(BaseModel):
    >>> from careamics.config import FCNAlgorithmConfig
    >>> config_dict = {
    ...     "algorithm": "n2v",
-    ...     "algorithm_type": "fcn",
    ...     "loss": "n2v",
    ...     "model": {
    ...         "architecture": "UNet",
@@ -65,11 +64,6 @@ class FCNAlgorithmConfig(BaseModel):
    )

    # Mandatory fields
-    # defined in SupportedAlgorithm
-    algorithm_type: Literal["fcn"]
-    """Algorithm type must be `fcn` (fully convolutional network) to differentiate this
-    configuration from LVAE."""
-
    algorithm: Literal["n2v", "care", "n2n", "custom"]
    """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
    model architecture."""
@@ -145,3 +139,14 @@ class FCNAlgorithmConfig(BaseModel):
        Pretty string.
        """
        return pformat(self.model_dump())
+
+    @classmethod
+    def get_compatible_algorithms(cls) -> list[str]:
+        """Get the list of compatible algorithms.
+
+        Returns
+        -------
+        list of str
+            List of compatible algorithms.
+        """
+        return ["n2v", "care", "n2n"]
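Note: with `algorithm_type` gone, callers can route a config dict through the new classmethod instead. A minimal sketch (the dict keys come from the doctest above; whether the minimal `model` section validates depends on `UNetModel` defaults, which this diff does not show):

from careamics.config import FCNAlgorithmConfig

config_dict = {
    "algorithm": "n2v",
    "loss": "n2v",
    "model": {"architecture": "UNet"},
}

# route on the algorithm name instead of the removed `algorithm_type` field
if config_dict["algorithm"] in FCNAlgorithmConfig.get_compatible_algorithms():
    config = FCNAlgorithmConfig(**config_dict)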
careamics/config/likelihood_model.py
CHANGED

@@ -2,16 +2,30 @@

 from typing import Literal, Optional, Union

+import numpy as np
 import torch
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
+from typing_extensions import Annotated

 from careamics.models.lvae.noise_models import (
     GaussianMixtureNoiseModel,
     MultiChannelNoiseModel,
 )
+from careamics.utils.serializers import _array_to_json, _to_torch

 NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]

+# TODO: this is a temporary solution to serialize and deserialize tensor fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such tensors to/from JSON files during, resp., training and evaluation.
+Tensor = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_torch),
+]
+"""Annotated tensor type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to tensors."""
+

 class GaussianLikelihoodConfig(BaseModel):
     """Gaussian likelihood configuration."""
@@ -31,13 +45,16 @@ class NMLikelihoodConfig(BaseModel):

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

-    data_mean: torch.Tensor = torch.zeros(1)
+    # TODO remove and use as parameters to the likelihood functions?
+    data_mean: Tensor = torch.zeros(1)
    """The mean of the data, used to unnormalize data for noise model evaluation.
    Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

-    data_std: torch.Tensor = torch.ones(1)
+    # TODO remove and use as parameters to the likelihood functions?
+    data_std: Tensor = torch.ones(1)
    """The standard deviation of the data, used to unnormalize data for noise
    model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

-    noise_model: Optional[NoiseModel] = None
+    # TODO: serialization/deserialization for this
+    noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
    """The noise model instance used to compute the likelihood."""
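A self-contained sketch of the round trip this `Tensor` annotated type enables. The `_array_to_json` and `_to_torch` helpers below are hypothetical stand-ins for the ones in `careamics.utils.serializers`, written only to make the example runnable:

import json
from typing import Union

import numpy as np
import torch
from pydantic import BaseModel, PlainSerializer, PlainValidator
from typing_extensions import Annotated


def _array_to_json(arr: Union[np.ndarray, torch.Tensor]) -> str:
    # hypothetical stand-in: dump the values as a JSON list string
    arr = arr.numpy() if isinstance(arr, torch.Tensor) else arr
    return json.dumps(arr.tolist())


def _to_torch(value: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
    # hypothetical stand-in: accept a JSON string, array, or tensor
    if isinstance(value, str):
        return torch.tensor(json.loads(value))
    return torch.as_tensor(value)


Tensor = Annotated[
    Union[np.ndarray, torch.Tensor],
    PlainSerializer(_array_to_json, return_type=str),
    PlainValidator(_to_torch),
]


class Stats(BaseModel):
    data_mean: Tensor = torch.zeros(1)


stats = Stats(data_mean=np.array([1.5, 2.5]))
dumped = stats.model_dump_json()              # tensor serialized as a JSON string
restored = Stats.model_validate_json(dumped)  # string validated back into a tensor
assert torch.equal(restored.data_mean, torch.tensor([1.5, 2.5]))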
careamics/config/nm_model.py
CHANGED
@@ -4,8 +4,30 @@ from pathlib import Path
 from typing import Literal, Optional, Union

 import numpy as np
-
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+import torch
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PlainSerializer,
+    PlainValidator,
+    model_validator,
+)
+from typing_extensions import Annotated, Self
+
+from careamics.utils.serializers import _array_to_json, _to_numpy
+
+# TODO: this is a temporary solution to serialize and deserialize array fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such arrays to/from JSON files during, resp., training and evaluation.
+Array = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_numpy),
+]
+"""Annotated array type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to arrays."""
+

 # TODO: add histogram-based noise model

@@ -26,13 +48,17 @@ class GaussianMixtureNMConfig(BaseModel):
    """Path to the directory where the trained noise model (*.npz) is saved in the
    `train` method."""

-    signal: Optional[Union[str, Path, np.ndarray]] = None
+    # TODO remove and use as parameters to the NM functions?
+    signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
    """Path to the file containing signal or respective numpy array."""

-    observation: Optional[Union[str, Path, np.ndarray]] = None
+    # TODO remove and use as parameters to the NM functions?
+    observation: Optional[Union[str, Path, np.ndarray]] = Field(
+        default=None, exclude=True
+    )
    """Path to the file containing observation or respective numpy array."""

-    weight: Optional[np.ndarray] = None
+    weight: Optional[Array] = None
    """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
    describing the GMM noise model, with each row corresponding to one
    parameter of each gaussian, namely [mean, standard deviation and weight].
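A small illustration of the documented `weight` shape. The row grouping in the comment is one plausible reading of "[mean, standard deviation and weight]", not something this diff confirms:

import numpy as np

# Toy GMM noise-model weight matrix: for n_gaussian gaussians whose mean,
# standard deviation and weight are each described by n_coeff coefficients,
# the matrix stacks 3 * n_gaussian coefficient rows.
n_gaussian, n_coeff = 2, 3
weight = np.zeros((3 * n_gaussian, n_coeff))
assert weight.shape == (6, 3)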
careamics/config/optimizer_models.py
CHANGED

@@ -44,7 +44,9 @@ class OptimizerModel(BaseModel):
    )

    # Mandatory field
-    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+    name: Literal["Adam", "SGD", "Adamax"] = Field(
+        default="Adam", validate_default=True
+    )
    """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

    # Optional parameters, empty dict default value to allow filtering dictionary
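With `"Adamax"` added to the Literal, a config like the following now validates. This is a sketch only: the `parameters` field name is an assumption based on the "Optional parameters" comment above, not shown in this diff:

from careamics.config.optimizer_models import OptimizerModel

# "Adamax" now passes the Literal check; `parameters` (assumed name) would
# carry the torch optimizer keyword arguments.
optim = OptimizerModel(name="Adamax", parameters={"lr": 1e-3})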
careamics/config/training_model.py
CHANGED

@@ -3,13 +3,9 @@
 from __future__ import annotations

 from pprint import pformat
-from typing import Literal, Optional
+from typing import Literal, Optional, Union

-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-)
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from .callback_model import CheckpointModel, EarlyStoppingModel

@@ -37,6 +33,20 @@ class TrainingConfig(BaseModel):
    num_epochs: int = Field(default=20, ge=1)
    """Number of epochs, greater than 0."""

+    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
+    """Numerical precision"""
+    max_steps: int = Field(default=-1, ge=-1)
+    """Maximum number of steps to train for. -1 means no limit."""
+    check_val_every_n_epoch: int = Field(default=1, ge=1)
+    """Validation step frequency."""
+    enable_progress_bar: bool = Field(default=True)
+    """Whether to enable the progress bar."""
+    accumulate_grad_batches: int = Field(default=1, ge=1)
+    """Number of batches to accumulate gradients over before stepping the optimizer."""
+    gradient_clip_val: Optional[Union[int, float]] = None
+    """The value to which to clip the gradient"""
+    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
+    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
    logger: Optional[Literal["wandb", "tensorboard"]] = None
    """Logger to use during training. If None, no logger will be used. Available
    loggers are defined in SupportedLogger."""
@@ -70,3 +80,22 @@ class TrainingConfig(BaseModel):
        Whether the logger is defined or not.
        """
        return self.logger is not None
+
+    @field_validator("max_steps")
+    @classmethod
+    def validate_max_steps(cls, max_steps: int) -> int:
+        """Validate the max_steps parameter.
+
+        Parameters
+        ----------
+        max_steps : int
+            Maximum number of steps to train for. -1 means no limit.
+
+        Returns
+        -------
+        int
+            Validated max_steps.
+        """
+        if max_steps == 0:
+            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
+        return max_steps
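The new fields mirror `pytorch_lightning.Trainer` arguments one-to-one. A minimal wiring sketch, assuming the remaining `TrainingConfig` fields keep their defaults (how CAREamics itself forwards them is not shown in this diff):

import pytorch_lightning as pl

from careamics.config.training_model import TrainingConfig

cfg = TrainingConfig(
    num_epochs=50,
    precision="16-mixed",
    max_steps=10_000,
    accumulate_grad_batches=2,
    gradient_clip_val=1.0,
)

# forward the config fields to the lightning Trainer of the same names
trainer = pl.Trainer(
    max_epochs=cfg.num_epochs,
    max_steps=cfg.max_steps,
    precision=cfg.precision,
    check_val_every_n_epoch=cfg.check_val_every_n_epoch,
    enable_progress_bar=cfg.enable_progress_bar,
    accumulate_grad_batches=cfg.accumulate_grad_batches,
    gradient_clip_val=cfg.gradient_clip_val,
    gradient_clip_algorithm=cfg.gradient_clip_algorithm,
)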
careamics/config/transformations/__init__.py
CHANGED

@@ -5,11 +5,14 @@ __all__ = [
     "XYFlipModel",
     "NormalizeModel",
     "XYRandomRotate90Model",
-    "
+    "TransformModel",
+    "TRANSFORMS_UNION",
 ]


 from .n2v_manipulate_model import N2VManipulateModel
 from .normalize_model import NormalizeModel
+from .transform_model import TransformModel
+from .transform_union import TRANSFORMS_UNION
 from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model
careamics/config/transformations/transform_union.py
ADDED

@@ -0,0 +1,20 @@
+"""Type used to represent all transformations users can create."""
+
+from typing import Union
+
+from pydantic import Discriminator
+from typing_extensions import Annotated
+
+from .n2v_manipulate_model import N2VManipulateModel
+from .xy_flip_model import XYFlipModel
+from .xy_random_rotate90_model import XYRandomRotate90Model
+
+TRANSFORMS_UNION = Annotated[
+    Union[
+        XYFlipModel,
+        XYRandomRotate90Model,
+        N2VManipulateModel,
+    ],
+    Discriminator("name"),  # used to tell the different transform models apart
+]
+"""Available transforms in CAREamics."""
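How such a `Discriminator("name")` union behaves at validation time, sketched with toy stand-in models (the real `name` literals of `XYFlipModel` and friends are not shown in this diff, so the values below are assumptions):

from typing import Literal, Union

from pydantic import BaseModel, Discriminator, TypeAdapter
from typing_extensions import Annotated


# Toy stand-ins: each transform model carries a Literal `name` field that
# the Discriminator reads to pick the right union member.
class FlipModel(BaseModel):
    name: Literal["XYFlip"]
    p: float = 0.5


class RotateModel(BaseModel):
    name: Literal["XYRandomRotate90"]
    p: float = 0.5


ToyTransformsUnion = Annotated[
    Union[FlipModel, RotateModel],
    Discriminator("name"),
]

adapter = TypeAdapter(ToyTransformsUnion)
transform = adapter.validate_python({"name": "XYFlip", "p": 0.8})
assert isinstance(transform, FlipModel)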
careamics/config/vae_algorithm_model.py
CHANGED

@@ -19,40 +19,7 @@ from .optimizer_models import LrSchedulerModel, OptimizerModel
 class VAEAlgorithmConfig(BaseModel):
     """Algorithm configuration.

-    This Pydantic model validates the parameters governing the components of the
-    training algorithm: which algorithm, loss function, model architecture, optimizer,
-    and learning rate scheduler to use.
-
-    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
-    only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
-    allows you to register your own architecture and select it using its name as
-    `name` in the custom pydantic model.
-
-    Attributes
-    ----------
-    algorithm : algorithm: Literal["musplit", "denoisplit", "custom"]
-        Algorithm to use.
-    loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
-        Loss function to use.
-    model : Union[LVAEModel, CustomModel]
-        Model architecture to use.
-    noise_model: Optional[MultiChannelNmModel]
-        Noise model to use.
-    noise_model_likelihood_model: Optional[NMLikelihoodModel]
-        Noise model likelihood model to use.
-    gaussian_likelihood_model: Optional[GaussianLikelihoodModel]
-        Gaussian likelihood model to use.
-    optimizer : OptimizerModel, optional
-        Optimizer to use.
-    lr_scheduler : LrSchedulerModel, optional
-        Learning rate scheduler to use.
-
-    Raises
-    ------
-    ValueError
-        Algorithm parameter type validation errors.
-    ValueError
-        If the algorithm, loss and model are not compatible.
+    # TODO

     Examples
     --------
@@ -70,8 +37,7 @@ class VAEAlgorithmConfig(BaseModel):
    # defined in SupportedAlgorithm
    # TODO: Use supported Enum classes for typing?
    # - values can still be passed as strings and they will be cast to Enum
-
-    algorithm: Literal["musplit", "denoisplit", "custom"]
+    algorithm: Literal["musplit", "denoisplit"]
    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
careamics/dataset/tiling/lvae_tiled_patching.py
CHANGED

@@ -8,6 +8,7 @@ import numpy as np
 from numpy.typing import NDArray

 from careamics.config.tile_information import TileInformation
+from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager


 def extract_tiles(
@@ -66,10 +67,12 @@
    # itertools.product is equivalent of nested loops

    stitch_size = tile_size - overlaps
-    for
+    for tile_grid_indices in itertools.product(
+        *[range(n) for n in tile_grid_shape]
+    ):

        # calculate crop coordinates
-        crop_coords_start = np.array(
+        crop_coords_start = np.array(tile_grid_indices) * stitch_size
        crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
            ...,
            *[
@@ -80,7 +83,7 @@
        tile = sample[crop_slices]

        tile_info = compute_tile_info(
-            np.array(
+            np.array(tile_grid_indices),
            np.array(data_shape),
            np.array(tile_size),
            np.array(overlaps),
@@ -93,19 +96,98 @@
        yield tile, tile_info


+def compute_tile_info_legacy(
+    grid_index_manager: GridIndexManager, index: int
+) -> TileInformation:
+    """
+    Compute the tile information for a tile at a given dataset index.
+
+    Parameters
+    ----------
+    grid_index_manager : GridIndexManager
+        The grid index manager that keeps track of tile locations.
+    index : int
+        The dataset index.
+
+    Returns
+    -------
+    TileInformation
+        Information that describes how to crop and stitch a tile to create a full image.
+
+    Raises
+    ------
+    ValueError
+        If `grid_index_manager.data_shape` does not have 4 or 5 dimensions.
+    """
+    data_shape = np.array(grid_index_manager.data_shape)
+    if len(data_shape) == 5:
+        n_spatial_dims = 3
+    elif len(data_shape) == 4:
+        n_spatial_dims = 2
+    else:
+        raise ValueError("Data shape must have 4 or 5 dimensions, equating to SC(Z)YX.")
+
+    stitch_coords_start = np.array(
+        grid_index_manager.get_location_from_dataset_idx(index)
+    )
+    stitch_coords_end = stitch_coords_start + np.array(grid_index_manager.grid_shape)
+
+    tile_coords_start = stitch_coords_start - grid_index_manager.patch_offset()
+
+    # --- replace out of bounds indices
+    out_of_lower_bound = stitch_coords_start < 0
+    out_of_upper_bound = stitch_coords_end > data_shape
+    stitch_coords_start[out_of_lower_bound] = 0
+    stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]
+
+    # TODO: TilingMode not in current version
+    # if grid_index_manager.tiling_mode == TilingMode.ShiftBoundary:
+    #     for dim in range(len(stitch_coords_start)):
+    #         if tile_coords_start[dim] == 0:
+    #             stitch_coords_start[dim] = 0
+    #         if tile_coords_end[dim] == grid_index_manager.data_shape[dim]:
+    #             tile_coords_end [dim]= grid_index_manager.data_shape[dim]

+    # --- calculate overlap crop coords
+    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
+    overlap_crop_coords_end = overlap_crop_coords_start + (
+        stitch_coords_end - stitch_coords_start
+    )
+
+    last_tile = index == grid_index_manager.total_grid_count() - 1
+
+    # --- combine start and end
+    stitch_coords = tuple(
+        (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+    )
+    overlap_crop_coords = tuple(
+        (start, end)
+        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+    )
+
+    tile_info = TileInformation(
+        array_shape=data_shape[1:],  # remove S dim
+        last_tile=last_tile,
+        overlap_crop_coords=overlap_crop_coords[-n_spatial_dims:],
+        stitch_coords=stitch_coords[-n_spatial_dims:],
+        sample_id=0,
+    )
+    return tile_info
+
+
 def compute_tile_info(
-
+    tile_grid_indices: NDArray[np.int_],
     data_shape: NDArray[np.int_],
     tile_size: NDArray[np.int_],
     overlaps: NDArray[np.int_],
     sample_id: int = 0,
 ) -> TileInformation:
     """
-    Compute the tile information for a tile with the coordinates `
+    Compute the tile information for a tile with the coordinates `tile_grid_indices`.

     Parameters
     ----------
-
+    tile_grid_indices : 1D np.array of int
         The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
         tiling the coordinates for the second tile in the first row of tiles would be
         (0, 1).
@@ -127,7 +209,7 @@ def compute_tile_info(

    # The extent of the tile which will make up part of the stitched image.
    stitch_size = tile_size - overlaps
-    stitch_coords_start =
+    stitch_coords_start = tile_grid_indices * stitch_size
    stitch_coords_end = stitch_coords_start + stitch_size

    tile_coords_start = stitch_coords_start - overlaps // 2
@@ -155,7 +237,7 @@

    # --- Check if last tile
    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
-    last_tile = (
+    last_tile = (tile_grid_indices == (tile_grid_shape - 1)).all()

    tile_info = TileInformation(
        array_shape=data_shape,
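A worked example of the grid arithmetic used by `compute_tile_info`, with assumed tile sizes (plain numpy, not the careamics API):

import numpy as np

# 2D tiling with tile_size 128 and overlaps 32 per axis.
tile_size = np.array([128, 128])
overlaps = np.array([32, 32])
stitch_size = tile_size - overlaps           # 96 px of each tile survive stitching

tile_grid_indices = np.array([0, 1])         # second tile in the first row
stitch_coords_start = tile_grid_indices * stitch_size    # [0, 96]
stitch_coords_end = stitch_coords_start + stitch_size    # [96, 192]
tile_coords_start = stitch_coords_start - overlaps // 2  # [-16, 80]

tile_grid_shape = np.array([4, 4])
last_tile = (tile_grid_indices == (tile_grid_shape - 1)).all()  # False here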
careamics/lightning/lightning_module.py
CHANGED

@@ -1,6 +1,6 @@
 """CAREamics Lightning module."""

-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Callable, Optional, Union

 import numpy as np
 import pytorch_lightning as L
@@ -271,6 +271,12 @@ class VAEModule(L.LightningModule):
        self.noise_model: NoiseModel = noise_model_factory(
            self.algorithm_config.noise_model
        )
+        # TODO: here we can add some code to check whether the noise model is not None
+        # and `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
+        # instead, None. In that case we could assign the noise model to the latter.
+        # This is particular useful when loading an algorithm config from file.
+        # Indeed, in that case the noise model in the nm likelihood is likely
+        # not available since excluded from serializaion.
        self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
            self.algorithm_config.noise_model_likelihood_model
        )
@@ -550,7 +556,6 @@ class VAEModule(L.LightningModule):

 # TODO: make this LVAE compatible (?)
 def create_careamics_module(
-    algorithm_type: Literal["fcn"],
     algorithm: Union[SupportedAlgorithm, str],
     loss: Union[SupportedLoss, str],
     architecture: Union[SupportedArchitecture, str],
@@ -567,8 +572,6 @@ def create_careamics_module(

    Parameters
    ----------
-    algorithm_type : Literal["fcn"]
-        Algorithm type to use for training.
    algorithm : SupportedAlgorithm or str
        Algorithm to use for training (see SupportedAlgorithm).
    loss : SupportedLoss or str
@@ -604,7 +607,6 @@ def create_careamics_module(
    if model_parameters is None:
        model_parameters = {}
    algorithm_configuration: dict[str, Any] = {
-        "algorithm_type": algorithm_type,
        "algorithm": algorithm,
        "loss": loss,
        "optimizer": {
@@ -623,10 +625,10 @@ def create_careamics_module(
    algorithm_configuration["model"] = model_configuration

    # call the parent init using an AlgorithmModel instance
-    if algorithm_type == "fcn":
+    algorithm_str = algorithm_configuration["algorithm"]
+    if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
        return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
    else:
        raise NotImplementedError(
-            f"Model {
-            f"implemented or unknown."
+            f"Model {algorithm_str} is not implemented or unknown."
        )
careamics/lightning/train_data_module.py
CHANGED

@@ -9,8 +9,8 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader

 from careamics.config import DataConfig
-from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
+from careamics.config.transformations import TransformModel
 from careamics.dataset.dataset_utils import (
     get_files_size,
     list_files,
@@ -472,7 +472,7 @@ def create_train_datamodule(
    axes: str,
    batch_size: int,
    val_data: Optional[Union[str, Path, NDArray]] = None,
-    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    transforms: Optional[list[TransformModel]] = None,
    train_target_data: Optional[Union[str, Path, NDArray]] = None,
    val_target_data: Optional[Union[str, Path, NDArray]] = None,
    read_source_func: Optional[Callable] = None,
careamics/losses/loss_factory.py
CHANGED
@@ -56,9 +56,9 @@ class LVAELossParameters:
    reconstruction_weight: float = 1.0
    """Weight for the reconstruction loss in the total net loss
    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.
-    """Weight for the muSplit loss (used in the muSplit-
-    denoisplit_weight: float =
+    musplit_weight: float = 0.1
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 0.9
    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
    """Type of KL divergence used as KL loss."""
careamics/losses/lvae/losses.py
CHANGED
@@ -137,8 +137,8 @@ def reconstruction_loss_musplit_denoisplit(
    recons_loss : torch.Tensor
        The reconstruction loss. Shape is (1, ).
    """
-    # TODO:
-    #
+    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
+    # (or viceversa)
    if predictions.shape[1] == 2 * targets.shape[1]:
        # predictions contain both mean and log-variance
        out_mean, _ = predictions.chunk(2, dim=1)
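A toy illustration of the branch above, plus the 0.1/0.9 weighting that `LVAELossParameters` now defaults to (the per-term loss values here are placeholders, not careamics output):

import torch

# When the network outputs twice as many channels as there are targets, the
# first half is interpreted as the mean and the second half as log-variance.
targets = torch.randn(2, 3, 64, 64)
predictions = torch.randn(2, 6, 64, 64)
if predictions.shape[1] == 2 * targets.shape[1]:
    out_mean, out_logvar = predictions.chunk(2, dim=1)

# The combined reconstruction loss mixes the two objectives with the new
# default weights (musplit_weight=0.1, denoisplit_weight=0.9):
musplit_term, denoisplit_term = torch.tensor(0.42), torch.tensor(0.17)
recons_loss = 0.1 * musplit_term + 0.9 * denoisplit_term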
careamics/lvae_training/dataset/__init__.py
ADDED

@@ -0,0 +1,15 @@
+from .multich_dataset import MultiChDloader
+from .lc_dataset import LCMultiChDloader
+from .multifile_dataset import MultiFileDset
+from .config import DatasetConfig
+from .types import DataType, DataSplitType, TilingMode
+
+__all__ = [
+    "DatasetConfig",
+    "MultiChDloader",
+    "LCMultiChDloader",
+    "MultiFileDset",
+    "DataType",
+    "DataSplitType",
+    "TilingMode",
+]