careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic; see the package's registry page for more details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/config/__init__.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Configuration module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"
|
|
4
|
+
"FCNAlgorithmConfig",
|
|
5
|
+
"VAEAlgorithmConfig",
|
|
5
6
|
"DataConfig",
|
|
6
7
|
"Configuration",
|
|
7
8
|
"CheckpointModel",
|
|
@@ -15,9 +16,9 @@ __all__ = [
|
|
|
15
16
|
"register_model",
|
|
16
17
|
"CustomModel",
|
|
17
18
|
"clear_custom_models",
|
|
19
|
+
"GaussianMixtureNMConfig",
|
|
20
|
+
"MultiChannelNMConfig",
|
|
18
21
|
]
|
|
19
|
-
|
|
20
|
-
from .algorithm_model import AlgorithmConfig
|
|
21
22
|
from .architectures import CustomModel, clear_custom_models, register_model
|
|
22
23
|
from .callback_model import CheckpointModel
|
|
23
24
|
from .configuration_factory import (
|
|
@@ -31,5 +32,8 @@ from .configuration_model import (
|
|
|
31
32
|
save_configuration,
|
|
32
33
|
)
|
|
33
34
|
from .data_model import DataConfig
|
|
35
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
34
36
|
from .inference_model import InferenceConfig
|
|
37
|
+
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
35
38
|
from .training_model import TrainingConfig
|
|
39
|
+
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
@@ -4,7 +4,7 @@ __all__ = [
|
|
|
4
4
|
"ArchitectureModel",
|
|
5
5
|
"CustomModel",
|
|
6
6
|
"UNetModel",
|
|
7
|
-
"
|
|
7
|
+
"LVAEModel",
|
|
8
8
|
"clear_custom_models",
|
|
9
9
|
"get_custom_model",
|
|
10
10
|
"register_model",
|
|
@@ -12,6 +12,6 @@ __all__ = [
|
|
|
12
12
|
|
|
13
13
|
from .architecture_model import ArchitectureModel
|
|
14
14
|
from .custom_model import CustomModel
|
|
15
|
+
from .lvae_model import LVAEModel
|
|
15
16
|
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
16
17
|
from .unet_model import UNetModel
|
|
17
|
-
from .vae_model import VAEModel
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import inspect
|
|
5
6
|
from pprint import pformat
|
|
6
7
|
from typing import Any, Literal
|
|
7
8
|
|
|
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
|
|
|
23
24
|
|
|
24
25
|
Attributes
|
|
25
26
|
----------
|
|
26
|
-
architecture : Literal["
|
|
27
|
-
Discriminator for the custom model, must be set to "
|
|
27
|
+
architecture : Literal["custom"]
|
|
28
|
+
Discriminator for the custom model, must be set to "custom".
|
|
28
29
|
name : str
|
|
29
30
|
Name of the custom model.
|
|
30
31
|
parameters : CustomParametersModel
|
|
31
|
-
|
|
32
|
+
All parameters, required for the initialization of the torch module have to be
|
|
33
|
+
passed here.
|
|
32
34
|
|
|
33
35
|
Raises
|
|
34
36
|
------
|
|
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
|
|
|
57
59
|
...
|
|
58
60
|
>>> # Create a configuration
|
|
59
61
|
>>> config_dict = {
|
|
60
|
-
... "architecture": "
|
|
62
|
+
... "architecture": "custom",
|
|
61
63
|
... "name": "my_linear",
|
|
62
64
|
... "in_features": 10,
|
|
63
65
|
... "out_features": 5,
|
|
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
|
|
|
71
73
|
)
|
|
72
74
|
|
|
73
75
|
# discriminator used for choosing the pydantic model in Model
|
|
74
|
-
architecture: Literal["
|
|
76
|
+
architecture: Literal["custom"]
|
|
75
77
|
"""Name of the architecture."""
|
|
76
78
|
|
|
77
|
-
# name of the custom model
|
|
78
79
|
name: str
|
|
79
80
|
"""Name of the custom model."""
|
|
80
81
|
|
|
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
|
|
|
120
121
|
get_custom_model(self.name)(**self.model_dump())
|
|
121
122
|
except Exception as e:
|
|
122
123
|
raise ValueError(
|
|
123
|
-
f"
|
|
124
|
+
f"while passing parameters to the model {e}. Verify that all "
|
|
124
125
|
f"mandatory parameters are provided, and that either the {e} accepts "
|
|
125
126
|
f"*args and **kwargs in its __init__() method, or that no additional"
|
|
126
|
-
f"parameter is provided."
|
|
127
|
+
f"parameter is provided. Trace: "
|
|
128
|
+
f"filename: {inspect.trace()[-1].filename}, function: "
|
|
129
|
+
f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
|
|
127
130
|
) from None
|
|
128
131
|
|
|
129
132
|
return self
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""LVAE Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .architecture_model import ArchitectureModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
|
|
12
|
+
class LVAEModel(ArchitectureModel):
    """
    Pydantic configuration of the LVAE architecture.

    Attributes
    ----------
    architecture : Literal["LVAE"]
        Discriminator for the LVAE architecture, must be set to "LVAE".
    input_shape : int
        Integer in [8, 1024], default 64. Presumably the spatial size of the
        input patches -- TODO confirm against the model implementation.
    multiscale_count : int
        Default 5. Semantics still to be clarified (see the TODO on the field).
    z_dims : list
        Latent dimensions, default [128, 128, 128, 128]. At least 2 entries
        are required. Presumably one entry per hierarchy level -- TODO confirm.
    output_channels : int
        Number of output channels, at least 1, default 1.
    encoder_n_filters : int
        Even integer in [8, 1024], default 64.
    decoder_n_filters : int
        Even integer in [8, 1024], default 64.
    encoder_dropout : float
        Encoder dropout value in [0.0, 0.9], default 0.1.
    decoder_dropout : float
        Decoder dropout value in [0.0, 0.9], default 0.1.
    nonlinearity : Literal["None", "Sigmoid", "Softmax", "Tanh", "ReLU",
        "LeakyReLU", "ELU"]
        Name of the nonlinearity, default "ELU".
    predict_logvar : Literal[None, "pixelwise"]
        Either None (default) or "pixelwise".
    analytical_kl : bool
        Default False. Presumably toggles an analytical KL computation --
        TODO confirm.
    """

    # Re-validate on attribute assignment and validate default values as well.
    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    # Discriminator used to select this pydantic model.
    architecture: Literal["LVAE"]
    input_shape: int = Field(default=64, ge=8, le=1024)
    # TODO clarify. An author note says 0 turns multiscale off and that the
    # value can/should be at most len(z_dims) + 1; this is not enforced yet.
    multiscale_count: int = Field(default=5)
    # Pydantic copies mutable defaults per instance, so this list is not shared.
    z_dims: list = Field(default=[128, 128, 128, 128])
    output_channels: int = Field(default=1, ge=1)
    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = Field(
        default="ELU",
    )

    predict_logvar: Literal[None, "pixelwise"] = None

    analytical_kl: bool = Field(
        default=False,
    )

    @field_validator("encoder_n_filters")
    @classmethod
    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
        """
        Validate that `encoder_n_filters` is even.

        Parameters
        ----------
        encoder_n_filters : int
            Number of encoder filters.

        Returns
        -------
        int
            Validated number of encoder filters.

        Raises
        ------
        ValueError
            If the number of filters is odd.
        """
        if encoder_n_filters % 2 != 0:
            raise ValueError(
                "Number of channels for the bottom layer must be even"
                f" (got {encoder_n_filters})."
            )

        return encoder_n_filters

    @field_validator("decoder_n_filters")
    @classmethod
    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
        """
        Validate that `decoder_n_filters` is even.

        Parameters
        ----------
        decoder_n_filters : int
            Number of decoder filters.

        Returns
        -------
        int
            Validated number of decoder filters.

        Raises
        ------
        ValueError
            If the number of filters is odd.
        """
        if decoder_n_filters % 2 != 0:
            raise ValueError(
                "Number of channels for the bottom layer must be even"
                f" (got {decoder_n_filters})."
            )

        return decoder_n_filters

    @field_validator("z_dims")
    @classmethod
    def validate_z_dims(cls, z_dims: list) -> list:
        """
        Validate that at least two z dimensions are provided.

        Parameters
        ----------
        z_dims : list
            Sequence of z dimensions.

        Returns
        -------
        list
            Validated z dimensions, unchanged.

        Raises
        ------
        ValueError
            If fewer than 2 z dimensions are provided.
        """
        if len(z_dims) < 2:
            raise ValueError(
                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
            )

        return z_dims

    @model_validator(mode="after")
    def validate_multiscale_count(self) -> Self:
        """
        Validate the multiscale count (currently a no-op).

        Returns
        -------
        Self
            The validated model, unchanged.
        """
        # TODO: a draft check requiring `multiscale_count` to be 0 or equal to
        # len(z_dims) - 1 exists but is disabled. The note on the field instead
        # mentions len(z_dims) + 1 as the bound. Settle the semantics before
        # enforcing either constraint here.
        return self

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        3D is not supported for this architecture: this method always raises.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("VAE is not implemented yet.")

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        Not supported for this architecture: this method always raises.

        Returns
        -------
        bool
            Whether the model is 3D or not (never returned, see Raises).

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("VAE is not implemented yet.")
|