careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/transforms/compose.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
"""A class chaining transforms together."""
|
|
2
2
|
|
|
3
|
-
from typing import Dict, List, Optional, Tuple, cast
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union, cast
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
6
|
|
|
7
|
-
from careamics.config.
|
|
7
|
+
from careamics.config.transformations import TransformModel
|
|
8
8
|
|
|
9
9
|
from .n2v_manipulate import N2VManipulate
|
|
10
10
|
from .normalize import Normalize
|
|
11
|
+
from .transform import Transform
|
|
11
12
|
from .xy_flip import XYFlip
|
|
12
13
|
from .xy_random_rotate90 import XYRandomRotate90
|
|
13
14
|
|
|
@@ -36,7 +37,7 @@ class Compose:
|
|
|
36
37
|
|
|
37
38
|
Parameters
|
|
38
39
|
----------
|
|
39
|
-
transform_list : List[
|
|
40
|
+
transform_list : List[TransformModel]
|
|
40
41
|
A list of dictionaries where each dictionary contains the name of a
|
|
41
42
|
transform and its parameters.
|
|
42
43
|
|
|
@@ -46,26 +47,27 @@ class Compose:
|
|
|
46
47
|
A callable that applies the transforms to the input data.
|
|
47
48
|
"""
|
|
48
49
|
|
|
49
|
-
def __init__(self, transform_list: List[
|
|
50
|
+
def __init__(self, transform_list: List[TransformModel]) -> None:
|
|
50
51
|
"""Instantiate a Compose object.
|
|
51
52
|
|
|
52
53
|
Parameters
|
|
53
54
|
----------
|
|
54
|
-
transform_list : List[
|
|
55
|
+
transform_list : List[TransformModel]
|
|
55
56
|
A list of dictionaries where each dictionary contains the name of a
|
|
56
57
|
transform and its parameters.
|
|
57
58
|
"""
|
|
58
59
|
# retrieve all available transforms
|
|
59
|
-
|
|
60
|
+
# TODO: correctly type hint get_all_transforms function output
|
|
61
|
+
all_transforms: dict[str, type[Transform]] = get_all_transforms()
|
|
60
62
|
|
|
61
63
|
# instantiate all transforms
|
|
62
|
-
self.transforms = [
|
|
64
|
+
self.transforms: list[Transform] = [
|
|
63
65
|
all_transforms[t.name](**t.model_dump()) for t in transform_list
|
|
64
66
|
]
|
|
65
67
|
|
|
66
68
|
def _chain_transforms(
|
|
67
|
-
self, patch:
|
|
68
|
-
) -> Tuple[
|
|
69
|
+
self, patch: NDArray, target: Optional[NDArray]
|
|
70
|
+
) -> Tuple[Optional[NDArray], ...]:
|
|
69
71
|
"""Chain transforms on the input data.
|
|
70
72
|
|
|
71
73
|
Parameters
|
|
@@ -80,16 +82,56 @@ class Compose:
|
|
|
80
82
|
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
81
83
|
The output of the transformations.
|
|
82
84
|
"""
|
|
83
|
-
params
|
|
85
|
+
params: Union[
|
|
86
|
+
tuple[NDArray, Optional[NDArray]],
|
|
87
|
+
tuple[NDArray, NDArray, NDArray], # N2VManipulate output
|
|
88
|
+
] = (patch, target)
|
|
84
89
|
|
|
85
90
|
for t in self.transforms:
|
|
86
|
-
|
|
91
|
+
# N2VManipulate returns tuple of 3 arrays
|
|
92
|
+
# - Other transforms return tuple of (patch, target, additional_arrays)
|
|
93
|
+
if isinstance(t, N2VManipulate):
|
|
94
|
+
patch, *_ = params
|
|
95
|
+
params = t(patch=patch)
|
|
96
|
+
else:
|
|
97
|
+
*params, _ = t(*params) # ignore additional_arrays dict
|
|
87
98
|
|
|
88
99
|
return params
|
|
89
100
|
|
|
101
|
+
def _chain_transforms_additional_arrays(
|
|
102
|
+
self,
|
|
103
|
+
patch: NDArray,
|
|
104
|
+
target: Optional[NDArray],
|
|
105
|
+
**additional_arrays: NDArray,
|
|
106
|
+
) -> Tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
107
|
+
"""Chain transforms on the input data, with additional arrays.
|
|
108
|
+
|
|
109
|
+
Parameters
|
|
110
|
+
----------
|
|
111
|
+
patch : np.ndarray
|
|
112
|
+
Input data.
|
|
113
|
+
target : Optional[np.ndarray]
|
|
114
|
+
Target data, by default None.
|
|
115
|
+
**additional_arrays : NDArray
|
|
116
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
117
|
+
`target`.
|
|
118
|
+
|
|
119
|
+
Returns
|
|
120
|
+
-------
|
|
121
|
+
Tuple[np.ndarray, Optional[np.ndarray], dict[str, np.ndarray]]
|
|
122
|
+
The output of the transformations.
|
|
123
|
+
"""
|
|
124
|
+
params = {"patch": patch, "target": target, **additional_arrays}
|
|
125
|
+
|
|
126
|
+
for t in self.transforms:
|
|
127
|
+
patch, target, additional_arrays = t(**params)
|
|
128
|
+
params = {"patch": patch, "target": target, **additional_arrays}
|
|
129
|
+
|
|
130
|
+
return patch, target, additional_arrays
|
|
131
|
+
|
|
90
132
|
def __call__(
|
|
91
|
-
self, patch:
|
|
92
|
-
) -> Tuple[
|
|
133
|
+
self, patch: NDArray, target: Optional[NDArray] = None
|
|
134
|
+
) -> Tuple[NDArray, ...]:
|
|
93
135
|
"""Apply the transforms to the input data.
|
|
94
136
|
|
|
95
137
|
Parameters
|
|
@@ -104,4 +146,37 @@ class Compose:
|
|
|
104
146
|
Tuple[np.ndarray, ...]
|
|
105
147
|
The output of the transformations.
|
|
106
148
|
"""
|
|
107
|
-
|
|
149
|
+
# TODO: solve casting Compose.__call__ output
|
|
150
|
+
return cast(Tuple[NDArray, ...], self._chain_transforms(patch, target))
|
|
151
|
+
|
|
152
|
+
def transform_with_additional_arrays(
|
|
153
|
+
self,
|
|
154
|
+
patch: NDArray,
|
|
155
|
+
target: Optional[NDArray] = None,
|
|
156
|
+
**additional_arrays: NDArray,
|
|
157
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
158
|
+
"""Apply the transforms to the input data, including additional arrays.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
patch : np.ndarray
|
|
163
|
+
The input data.
|
|
164
|
+
target : Optional[np.ndarray], optional
|
|
165
|
+
Target data, by default None.
|
|
166
|
+
**additional_arrays : NDArray
|
|
167
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
168
|
+
`target`.
|
|
169
|
+
|
|
170
|
+
Returns
|
|
171
|
+
-------
|
|
172
|
+
NDArray
|
|
173
|
+
The transformed patch.
|
|
174
|
+
NDArray | None
|
|
175
|
+
The transformed target.
|
|
176
|
+
dict of {str, NDArray}
|
|
177
|
+
Transformed additional arrays. Keys correspond to the keyword argument
|
|
178
|
+
names.
|
|
179
|
+
"""
|
|
180
|
+
return self._chain_transforms_additional_arrays(
|
|
181
|
+
patch, target, **additional_arrays
|
|
182
|
+
)
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Any, Literal, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
8
9
|
from careamics.transforms.transform import Transform
|
|
@@ -98,8 +99,8 @@ class N2VManipulate(Transform):
|
|
|
98
99
|
self.rng = np.random.default_rng(seed=seed)
|
|
99
100
|
|
|
100
101
|
def __call__(
|
|
101
|
-
self, patch:
|
|
102
|
-
) -> Tuple[
|
|
102
|
+
self, patch: NDArray, *args: Any, **kwargs: Any
|
|
103
|
+
) -> Tuple[NDArray, NDArray, NDArray]:
|
|
103
104
|
"""Apply the transform to the image.
|
|
104
105
|
|
|
105
106
|
Parameters
|
|
@@ -142,5 +143,8 @@ class N2VManipulate(Transform):
|
|
|
142
143
|
else:
|
|
143
144
|
raise ValueError(f"Unknown masking strategy ({self.strategy}).")
|
|
144
145
|
|
|
146
|
+
# TODO: Output does not match other transforms, how to resolve?
|
|
147
|
+
# - Don't include in Compose and apply after if algorithm is N2V?
|
|
148
|
+
# - or just don't return patch? but then mask is in the target position
|
|
145
149
|
# TODO why return patch?
|
|
146
150
|
return masked, patch, mask
|
|
@@ -90,8 +90,11 @@ class Normalize(Transform):
|
|
|
90
90
|
self.eps = 1e-6
|
|
91
91
|
|
|
92
92
|
def __call__(
|
|
93
|
-
self,
|
|
94
|
-
|
|
93
|
+
self,
|
|
94
|
+
patch: np.ndarray,
|
|
95
|
+
target: Optional[NDArray] = None,
|
|
96
|
+
**additional_arrays: NDArray,
|
|
97
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
95
98
|
"""Apply the transform to the source patch and the target (optional).
|
|
96
99
|
|
|
97
100
|
Parameters
|
|
@@ -100,6 +103,9 @@ class Normalize(Transform):
|
|
|
100
103
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
101
104
|
target : NDArray, optional
|
|
102
105
|
Target for the patch, by default None.
|
|
106
|
+
**additional_arrays : NDArray
|
|
107
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
108
|
+
`target`.
|
|
103
109
|
|
|
104
110
|
Returns
|
|
105
111
|
-------
|
|
@@ -111,6 +117,11 @@ class Normalize(Transform):
|
|
|
111
117
|
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
112
118
|
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
113
119
|
)
|
|
120
|
+
if len(additional_arrays) != 0:
|
|
121
|
+
raise NotImplementedError(
|
|
122
|
+
"Transforming additional arrays is currently not supported for "
|
|
123
|
+
"`Normalize`."
|
|
124
|
+
)
|
|
114
125
|
|
|
115
126
|
# reshape mean and std and apply the normalization to the patch
|
|
116
127
|
means = _reshape_stats(self.image_means, patch.ndim)
|
|
@@ -129,7 +140,7 @@ class Normalize(Transform):
|
|
|
129
140
|
else:
|
|
130
141
|
norm_target = None
|
|
131
142
|
|
|
132
|
-
return norm_patch, norm_target
|
|
143
|
+
return norm_patch, norm_target, additional_arrays
|
|
133
144
|
|
|
134
145
|
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
135
146
|
"""
|
|
@@ -161,7 +161,7 @@ def _get_stratified_coords(
|
|
|
161
161
|
coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
|
|
162
162
|
|
|
163
163
|
grid_random_increment = rng.integers(
|
|
164
|
-
_odd_jitter_func(float(max(steps)), rng)
|
|
164
|
+
_odd_jitter_func(float(max(steps)), rng) # type: ignore
|
|
165
165
|
* np.ones_like(coordinate_grid).astype(np.int32)
|
|
166
166
|
- 1,
|
|
167
167
|
size=coordinate_grid.shape,
|
careamics/transforms/xy_flip.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""XY flip transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
@@ -78,8 +79,11 @@ class XYFlip(Transform):
|
|
|
78
79
|
self.rng = np.random.default_rng(seed=seed)
|
|
79
80
|
|
|
80
81
|
def __call__(
|
|
81
|
-
self,
|
|
82
|
-
|
|
82
|
+
self,
|
|
83
|
+
patch: NDArray,
|
|
84
|
+
target: Optional[NDArray] = None,
|
|
85
|
+
**additional_arrays: NDArray,
|
|
86
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
83
87
|
"""Apply the transform to the source patch and the target (optional).
|
|
84
88
|
|
|
85
89
|
Parameters
|
|
@@ -88,6 +92,9 @@ class XYFlip(Transform):
|
|
|
88
92
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
89
93
|
target : Optional[np.ndarray], optional
|
|
90
94
|
Target for the patch, by default None.
|
|
95
|
+
**additional_arrays : NDArray
|
|
96
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
97
|
+
`target`.
|
|
91
98
|
|
|
92
99
|
Returns
|
|
93
100
|
-------
|
|
@@ -95,17 +102,20 @@ class XYFlip(Transform):
|
|
|
95
102
|
Transformed patch and target.
|
|
96
103
|
"""
|
|
97
104
|
if self.rng.random() > self.p:
|
|
98
|
-
return patch, target
|
|
105
|
+
return patch, target, additional_arrays
|
|
99
106
|
|
|
100
107
|
# choose an axis to flip
|
|
101
108
|
axis = self.rng.choice(self.axis_indices)
|
|
102
109
|
|
|
103
110
|
patch_transformed = self._apply(patch, axis)
|
|
104
111
|
target_transformed = self._apply(target, axis) if target is not None else None
|
|
112
|
+
additional_transformed = {
|
|
113
|
+
key: self._apply(array, axis) for key, array in additional_arrays.items()
|
|
114
|
+
}
|
|
105
115
|
|
|
106
|
-
return patch_transformed, target_transformed
|
|
116
|
+
return patch_transformed, target_transformed, additional_transformed
|
|
107
117
|
|
|
108
|
-
def _apply(self, patch:
|
|
118
|
+
def _apply(self, patch: NDArray, axis: int) -> NDArray:
|
|
109
119
|
"""Apply the transform to the image.
|
|
110
120
|
|
|
111
121
|
Parameters
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
@@ -49,8 +50,11 @@ class XYRandomRotate90(Transform):
|
|
|
49
50
|
self.rng = np.random.default_rng(seed=seed)
|
|
50
51
|
|
|
51
52
|
def __call__(
|
|
52
|
-
self,
|
|
53
|
-
|
|
53
|
+
self,
|
|
54
|
+
patch: NDArray,
|
|
55
|
+
target: Optional[NDArray] = None,
|
|
56
|
+
**additional_arrays: NDArray,
|
|
57
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
54
58
|
"""Apply the transform to the source patch and the target (optional).
|
|
55
59
|
|
|
56
60
|
Parameters
|
|
@@ -59,6 +63,9 @@ class XYRandomRotate90(Transform):
|
|
|
59
63
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
60
64
|
target : Optional[np.ndarray], optional
|
|
61
65
|
Target for the patch, by default None.
|
|
66
|
+
**additional_arrays : NDArray
|
|
67
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
68
|
+
`target`.
|
|
62
69
|
|
|
63
70
|
Returns
|
|
64
71
|
-------
|
|
@@ -66,7 +73,7 @@ class XYRandomRotate90(Transform):
|
|
|
66
73
|
Transformed patch and target.
|
|
67
74
|
"""
|
|
68
75
|
if self.rng.random() > self.p:
|
|
69
|
-
return patch, target
|
|
76
|
+
return patch, target, additional_arrays
|
|
70
77
|
|
|
71
78
|
# number of rotations
|
|
72
79
|
n_rot = self.rng.integers(1, 4)
|
|
@@ -76,12 +83,14 @@ class XYRandomRotate90(Transform):
|
|
|
76
83
|
target_transformed = (
|
|
77
84
|
self._apply(target, n_rot, axes) if target is not None else None
|
|
78
85
|
)
|
|
86
|
+
additional_transformed = {
|
|
87
|
+
key: self._apply(array, n_rot, axes)
|
|
88
|
+
for key, array in additional_arrays.items()
|
|
89
|
+
}
|
|
79
90
|
|
|
80
|
-
return patch_transformed, target_transformed
|
|
91
|
+
return patch_transformed, target_transformed, additional_transformed
|
|
81
92
|
|
|
82
|
-
def _apply(
|
|
83
|
-
self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
|
|
84
|
-
) -> np.ndarray:
|
|
93
|
+
def _apply(self, patch: NDArray, n_rot: int, axes: Tuple[int, int]) -> NDArray:
|
|
85
94
|
"""Apply the transform to the image.
|
|
86
95
|
|
|
87
96
|
Parameters
|