careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the package's registry page for details.
- careamics/careamist.py +25 -17
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/architectures/lvae_model.py +0 -4
- careamics/config/configuration_factory.py +480 -177
- careamics/config/configuration_model.py +1 -2
- careamics/config/data_model.py +1 -15
- careamics/config/fcn_algorithm_model.py +14 -9
- careamics/config/likelihood_model.py +21 -4
- careamics/config/nm_model.py +31 -5
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +2 -36
- careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
- careamics/lightning/lightning_module.py +10 -8
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/loss_factory.py +3 -3
- careamics/losses/lvae/losses.py +2 -2
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
- careamics/lvae_training/dataset/lc_dataset.py +28 -20
- careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +4 -2
- careamics/model_io/bmz_io.py +6 -5
- careamics/models/lvae/likelihoods.py +18 -9
- careamics/models/lvae/lvae.py +12 -16
- careamics/models/lvae/noise_models.py +1 -1
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +204 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
- careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
- careamics/lvae_training/dataset/data_utils.py +0 -701
- careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -22,7 +22,7 @@ def get_unzip_path(zip_path: Union[Path, str]) -> Path:
|
|
|
22
22
|
return zip_path.parent / (str(zip_path.name) + ".unzip")
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def create_env_text(pytorch_version: str) -> str:
|
|
25
|
+
def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
|
|
26
26
|
"""Create environment yaml content for the bioimage model.
|
|
27
27
|
|
|
28
28
|
This installs an environment with the specified pytorch version and the latest
|
|
@@ -32,6 +32,8 @@ def create_env_text(pytorch_version: str) -> str:
|
|
|
32
32
|
----------
|
|
33
33
|
pytorch_version : str
|
|
34
34
|
Pytorch version.
|
|
35
|
+
torchvision_version : str
|
|
36
|
+
Torchvision version.
|
|
35
37
|
|
|
36
38
|
Returns
|
|
37
39
|
-------
|
|
@@ -43,7 +45,7 @@ def create_env_text(pytorch_version: str) -> str:
|
|
|
43
45
|
f"dependencies:\n"
|
|
44
46
|
f" - python=3.10\n"
|
|
45
47
|
f" - pytorch={pytorch_version}\n"
|
|
46
|
-
f" - torchvision={
|
|
48
|
+
f" - torchvision={torchvision_version}\n"
|
|
47
49
|
f" - pip\n"
|
|
48
50
|
f" - pip:\n"
|
|
49
51
|
f" - git+https://github.com/CAREamics/careamics.git\n"
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -8,7 +8,9 @@ import numpy as np
|
|
|
8
8
|
import pkg_resources
|
|
9
9
|
from bioimageio.core import load_description, test_model
|
|
10
10
|
from bioimageio.spec import ValidationSummary, save_bioimageio_package
|
|
11
|
-
from torch import __version__
|
|
11
|
+
from torch import __version__ as PYTORCH_VERSION
|
|
12
|
+
from torch import load, save
|
|
13
|
+
from torchvision import __version__ as TORCHVISION_VERSION
|
|
12
14
|
|
|
13
15
|
from careamics.config import Configuration, load_configuration, save_configuration
|
|
14
16
|
from careamics.config.support import SupportedArchitecture
|
|
@@ -141,7 +143,6 @@ def export_to_bmz(
|
|
|
141
143
|
path_to_archive.parent.mkdir(parents=True, exist_ok=True)
|
|
142
144
|
|
|
143
145
|
# versions
|
|
144
|
-
pytorch_version = __version__
|
|
145
146
|
careamics_version = pkg_resources.get_distribution("careamics").version
|
|
146
147
|
|
|
147
148
|
# save files in temporary folder
|
|
@@ -151,7 +152,7 @@ def export_to_bmz(
|
|
|
151
152
|
# create environment file
|
|
152
153
|
# TODO move in bioimage module
|
|
153
154
|
env_path = temp_path / "environment.yml"
|
|
154
|
-
env_path.write_text(create_env_text(
|
|
155
|
+
env_path.write_text(create_env_text(PYTORCH_VERSION, TORCHVISION_VERSION))
|
|
155
156
|
|
|
156
157
|
# export input and ouputs
|
|
157
158
|
inputs = temp_path / "inputs.npy"
|
|
@@ -174,7 +175,7 @@ def export_to_bmz(
|
|
|
174
175
|
inputs=inputs,
|
|
175
176
|
outputs=outputs,
|
|
176
177
|
weights_path=weight_path,
|
|
177
|
-
torch_version=
|
|
178
|
+
torch_version=PYTORCH_VERSION,
|
|
178
179
|
careamics_version=careamics_version,
|
|
179
180
|
config_path=config_path,
|
|
180
181
|
env_path=env_path,
|
|
@@ -183,7 +184,7 @@ def export_to_bmz(
|
|
|
183
184
|
)
|
|
184
185
|
|
|
185
186
|
# test model description
|
|
186
|
-
summary: ValidationSummary = test_model(model_description
|
|
187
|
+
summary: ValidationSummary = test_model(model_description)
|
|
187
188
|
if summary.status == "failed":
|
|
188
189
|
raise ValueError(f"Model description test failed: {summary}")
|
|
189
190
|
|
|
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|
|
7
7
|
import math
|
|
8
8
|
from typing import Literal, Union, TYPE_CHECKING, Any, Optional
|
|
9
9
|
|
|
10
|
+
import numpy as np
|
|
10
11
|
import torch
|
|
11
12
|
from torch import nn
|
|
12
13
|
|
|
@@ -287,30 +288,37 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
287
288
|
|
|
288
289
|
def __init__(
|
|
289
290
|
self,
|
|
290
|
-
data_mean: torch.Tensor,
|
|
291
|
-
data_std: torch.Tensor,
|
|
292
|
-
noiseModel: NoiseModel,
|
|
291
|
+
data_mean: Union[np.ndarray, torch.Tensor],
|
|
292
|
+
data_std: Union[np.ndarray, torch.Tensor],
|
|
293
|
+
noiseModel: NoiseModel,
|
|
293
294
|
):
|
|
294
295
|
"""Constructor.
|
|
295
296
|
|
|
296
297
|
Parameters
|
|
297
298
|
----------
|
|
298
|
-
data_mean: torch.Tensor
|
|
299
|
+
data_mean: Union[np.ndarray, torch.Tensor]
|
|
299
300
|
The mean of the data, used to unnormalize data for noise model evaluation.
|
|
300
|
-
data_std: torch.Tensor
|
|
301
|
+
data_std: Union[np.ndarray, torch.Tensor]
|
|
301
302
|
The standard deviation of the data, used to unnormalize data for noise
|
|
302
303
|
model evaluation.
|
|
303
304
|
noiseModel: NoiseModel
|
|
304
305
|
The noise model instance used to compute the likelihood.
|
|
305
306
|
"""
|
|
306
307
|
super().__init__()
|
|
307
|
-
self.data_mean = data_mean
|
|
308
|
-
self.data_std = data_std
|
|
308
|
+
self.data_mean = torch.Tensor(data_mean)
|
|
309
|
+
self.data_std = torch.Tensor(data_std)
|
|
309
310
|
self.noiseModel = noiseModel
|
|
310
311
|
|
|
311
|
-
def
|
|
312
|
+
def _set_params_to_same_device_as(
|
|
312
313
|
self, correct_device_tensor: torch.Tensor
|
|
313
|
-
) -> None:
|
|
314
|
+
) -> None:
|
|
315
|
+
"""Set the parameters to the same device as the input tensor.
|
|
316
|
+
|
|
317
|
+
Parameters
|
|
318
|
+
----------
|
|
319
|
+
correct_device_tensor: torch.Tensor
|
|
320
|
+
The tensor whose device is used to set the parameters.
|
|
321
|
+
"""
|
|
314
322
|
if self.data_mean.device != correct_device_tensor.device:
|
|
315
323
|
self.data_mean = self.data_mean.to(correct_device_tensor.device)
|
|
316
324
|
self.data_std = self.data_std.to(correct_device_tensor.device)
|
|
@@ -355,6 +363,7 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
355
363
|
torch.Tensor
|
|
356
364
|
The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
|
|
357
365
|
"""
|
|
366
|
+
self._set_params_to_same_device_as(x)
|
|
358
367
|
predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
|
|
359
368
|
x_denormalized = x * self.data_std + self.data_mean
|
|
360
369
|
likelihoods = self.noiseModel.likelihood(
|
careamics/models/lvae/lvae.py
CHANGED
|
@@ -38,7 +38,6 @@ class LadderVAE(nn.Module):
|
|
|
38
38
|
decoder_dropout: float,
|
|
39
39
|
nonlinearity: str,
|
|
40
40
|
predict_logvar: bool,
|
|
41
|
-
enable_noise_model: bool,
|
|
42
41
|
analytical_kl: bool,
|
|
43
42
|
):
|
|
44
43
|
"""
|
|
@@ -62,15 +61,12 @@ class LadderVAE(nn.Module):
|
|
|
62
61
|
self.decoder_dropout = decoder_dropout
|
|
63
62
|
self.nonlin = nonlinearity
|
|
64
63
|
self.predict_logvar = predict_logvar
|
|
65
|
-
self.enable_noise_model = enable_noise_model
|
|
66
|
-
|
|
67
64
|
self.analytical_kl = analytical_kl
|
|
68
65
|
# -------------------------------------------------------
|
|
69
66
|
|
|
70
67
|
# -------------------------------------------------------
|
|
71
68
|
# Model attributes -> Hardcoded
|
|
72
69
|
self.model_type = ModelType.LadderVae # TODO remove !
|
|
73
|
-
self.model_type = ModelType.LadderVae # TODO remove !
|
|
74
70
|
self.encoder_blocks_per_layer = 1
|
|
75
71
|
self.decoder_blocks_per_layer = 1
|
|
76
72
|
self.bottomup_batchnorm = True
|
|
@@ -94,13 +90,6 @@ class LadderVAE(nn.Module):
|
|
|
94
90
|
self._stochastic_use_naive_exponential = False
|
|
95
91
|
self._enable_topdown_normalize_factor = True
|
|
96
92
|
|
|
97
|
-
# Noise model attributes -> Hardcoded
|
|
98
|
-
self.noise_model_type = "gmm"
|
|
99
|
-
self.denoise_channel = (
|
|
100
|
-
"input" # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'}
|
|
101
|
-
)
|
|
102
|
-
self.noise_model_learnable = False
|
|
103
|
-
|
|
104
93
|
# Attributes that handle LC -> Hardcoded
|
|
105
94
|
self.enable_multiscale = (
|
|
106
95
|
self._multiscale_count is not None and self._multiscale_count > 1
|
|
@@ -806,11 +795,18 @@ class LadderVAE(nn.Module):
|
|
|
806
795
|
|
|
807
796
|
# return samples
|
|
808
797
|
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
798
|
+
def reset_for_different_output_size(self, output_size: int) -> None:
|
|
799
|
+
"""Reset shape of output and latent tensors for different output size.
|
|
800
|
+
|
|
801
|
+
Used during evaluation to reset expected shapes of tensors when
|
|
802
|
+
input/output shape changes.
|
|
803
|
+
For instance, it is needed when the model was trained on, say, 64x64 sized
|
|
804
|
+
patches, but prediction is done on 128x128 patches.
|
|
805
|
+
"""
|
|
806
|
+
for i in range(self.n_layers):
|
|
807
|
+
sz = output_size // 2 ** (1 + i)
|
|
808
|
+
self.bottom_up_layers[i].output_expected_shape = (sz, sz)
|
|
809
|
+
self.top_down_layers[i].latent_shape = (output_size, output_size)
|
|
814
810
|
|
|
815
811
|
def pad_input(self, x):
|
|
816
812
|
"""
|
|
@@ -76,7 +76,7 @@ def train_gm_noise_model(
|
|
|
76
76
|
# TODO any training params ? Different channels ?
|
|
77
77
|
noise_model = GaussianMixtureNoiseModel(model_config)
|
|
78
78
|
# TODO revisit config unpacking
|
|
79
|
-
noise_model.train_noise_model(
|
|
79
|
+
noise_model.train_noise_model(model_config.signal, model_config.observation)
|
|
80
80
|
return noise_model
|
|
81
81
|
|
|
82
82
|
|
careamics/transforms/compose.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
"""A class chaining transforms together."""
|
|
2
2
|
|
|
3
|
-
from typing import Dict, List, Optional, Tuple, cast
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union, cast
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
6
|
|
|
7
|
-
from careamics.config.
|
|
7
|
+
from careamics.config.transformations import TransformModel
|
|
8
8
|
|
|
9
9
|
from .n2v_manipulate import N2VManipulate
|
|
10
10
|
from .normalize import Normalize
|
|
11
|
+
from .transform import Transform
|
|
11
12
|
from .xy_flip import XYFlip
|
|
12
13
|
from .xy_random_rotate90 import XYRandomRotate90
|
|
13
14
|
|
|
@@ -36,7 +37,7 @@ class Compose:
|
|
|
36
37
|
|
|
37
38
|
Parameters
|
|
38
39
|
----------
|
|
39
|
-
transform_list : List[
|
|
40
|
+
transform_list : List[TransformModel]
|
|
40
41
|
A list of dictionaries where each dictionary contains the name of a
|
|
41
42
|
transform and its parameters.
|
|
42
43
|
|
|
@@ -46,26 +47,27 @@ class Compose:
|
|
|
46
47
|
A callable that applies the transforms to the input data.
|
|
47
48
|
"""
|
|
48
49
|
|
|
49
|
-
def __init__(self, transform_list: List[
|
|
50
|
+
def __init__(self, transform_list: List[TransformModel]) -> None:
|
|
50
51
|
"""Instantiate a Compose object.
|
|
51
52
|
|
|
52
53
|
Parameters
|
|
53
54
|
----------
|
|
54
|
-
transform_list : List[
|
|
55
|
+
transform_list : List[TransformModel]
|
|
55
56
|
A list of dictionaries where each dictionary contains the name of a
|
|
56
57
|
transform and its parameters.
|
|
57
58
|
"""
|
|
58
59
|
# retrieve all available transforms
|
|
59
|
-
|
|
60
|
+
# TODO: correctly type hint get_all_transforms function output
|
|
61
|
+
all_transforms: dict[str, type[Transform]] = get_all_transforms()
|
|
60
62
|
|
|
61
63
|
# instantiate all transforms
|
|
62
|
-
self.transforms = [
|
|
64
|
+
self.transforms: list[Transform] = [
|
|
63
65
|
all_transforms[t.name](**t.model_dump()) for t in transform_list
|
|
64
66
|
]
|
|
65
67
|
|
|
66
68
|
def _chain_transforms(
|
|
67
|
-
self, patch:
|
|
68
|
-
) -> Tuple[
|
|
69
|
+
self, patch: NDArray, target: Optional[NDArray]
|
|
70
|
+
) -> Tuple[Optional[NDArray], ...]:
|
|
69
71
|
"""Chain transforms on the input data.
|
|
70
72
|
|
|
71
73
|
Parameters
|
|
@@ -80,16 +82,56 @@ class Compose:
|
|
|
80
82
|
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
81
83
|
The output of the transformations.
|
|
82
84
|
"""
|
|
83
|
-
params
|
|
85
|
+
params: Union[
|
|
86
|
+
tuple[NDArray, Optional[NDArray]],
|
|
87
|
+
tuple[NDArray, NDArray, NDArray], # N2VManiupulate output
|
|
88
|
+
] = (patch, target)
|
|
84
89
|
|
|
85
90
|
for t in self.transforms:
|
|
86
|
-
|
|
91
|
+
# N2VManipulate returns tuple of 3 arrays
|
|
92
|
+
# - Other transoforms return tuple of (patch, target, additional_arrays)
|
|
93
|
+
if isinstance(t, N2VManipulate):
|
|
94
|
+
patch, *_ = params
|
|
95
|
+
params = t(patch=patch)
|
|
96
|
+
else:
|
|
97
|
+
*params, _ = t(*params) # ignore additional_arrays dict
|
|
87
98
|
|
|
88
99
|
return params
|
|
89
100
|
|
|
101
|
+
def _chain_transforms_additional_arrays(
|
|
102
|
+
self,
|
|
103
|
+
patch: NDArray,
|
|
104
|
+
target: Optional[NDArray],
|
|
105
|
+
**additional_arrays: NDArray,
|
|
106
|
+
) -> Tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
107
|
+
"""Chain transforms on the input data, with additional arrays.
|
|
108
|
+
|
|
109
|
+
Parameters
|
|
110
|
+
----------
|
|
111
|
+
patch : np.ndarray
|
|
112
|
+
Input data.
|
|
113
|
+
target : Optional[np.ndarray]
|
|
114
|
+
Target data, by default None.
|
|
115
|
+
**additional_arrays : NDArray
|
|
116
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
117
|
+
`target`.
|
|
118
|
+
|
|
119
|
+
Returns
|
|
120
|
+
-------
|
|
121
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
122
|
+
The output of the transformations.
|
|
123
|
+
"""
|
|
124
|
+
params = {"patch": patch, "target": target, **additional_arrays}
|
|
125
|
+
|
|
126
|
+
for t in self.transforms:
|
|
127
|
+
patch, target, additional_arrays = t(**params)
|
|
128
|
+
params = {"patch": patch, "target": target, **additional_arrays}
|
|
129
|
+
|
|
130
|
+
return patch, target, additional_arrays
|
|
131
|
+
|
|
90
132
|
def __call__(
|
|
91
|
-
self, patch:
|
|
92
|
-
) -> Tuple[
|
|
133
|
+
self, patch: NDArray, target: Optional[NDArray] = None
|
|
134
|
+
) -> Tuple[NDArray, ...]:
|
|
93
135
|
"""Apply the transforms to the input data.
|
|
94
136
|
|
|
95
137
|
Parameters
|
|
@@ -104,4 +146,37 @@ class Compose:
|
|
|
104
146
|
Tuple[np.ndarray, ...]
|
|
105
147
|
The output of the transformations.
|
|
106
148
|
"""
|
|
107
|
-
|
|
149
|
+
# TODO: solve casting Compose.__call__ ouput
|
|
150
|
+
return cast(Tuple[NDArray, ...], self._chain_transforms(patch, target))
|
|
151
|
+
|
|
152
|
+
def transform_with_additional_arrays(
|
|
153
|
+
self,
|
|
154
|
+
patch: NDArray,
|
|
155
|
+
target: Optional[NDArray] = None,
|
|
156
|
+
**additional_arrays: NDArray,
|
|
157
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
158
|
+
"""Apply the transforms to the input data, including additional arrays.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
patch : np.ndarray
|
|
163
|
+
The input data.
|
|
164
|
+
target : Optional[np.ndarray], optional
|
|
165
|
+
Target data, by default None.
|
|
166
|
+
**additional_arrays : NDArray
|
|
167
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
168
|
+
`target`.
|
|
169
|
+
|
|
170
|
+
Returns
|
|
171
|
+
-------
|
|
172
|
+
NDArray
|
|
173
|
+
The transformed patch.
|
|
174
|
+
NDArray | None
|
|
175
|
+
The transformed target.
|
|
176
|
+
dict of {str, NDArray}
|
|
177
|
+
Transformed additional arrays. Keys correspond to the keyword argument
|
|
178
|
+
names.
|
|
179
|
+
"""
|
|
180
|
+
return self._chain_transforms_additional_arrays(
|
|
181
|
+
patch, target, **additional_arrays
|
|
182
|
+
)
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Any, Literal, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
8
9
|
from careamics.transforms.transform import Transform
|
|
@@ -98,8 +99,8 @@ class N2VManipulate(Transform):
|
|
|
98
99
|
self.rng = np.random.default_rng(seed=seed)
|
|
99
100
|
|
|
100
101
|
def __call__(
|
|
101
|
-
self, patch:
|
|
102
|
-
) -> Tuple[
|
|
102
|
+
self, patch: NDArray, *args: Any, **kwargs: Any
|
|
103
|
+
) -> Tuple[NDArray, NDArray, NDArray]:
|
|
103
104
|
"""Apply the transform to the image.
|
|
104
105
|
|
|
105
106
|
Parameters
|
|
@@ -142,5 +143,8 @@ class N2VManipulate(Transform):
|
|
|
142
143
|
else:
|
|
143
144
|
raise ValueError(f"Unknown masking strategy ({self.strategy}).")
|
|
144
145
|
|
|
146
|
+
# TODO: Output does not match other transforms, how to resolve?
|
|
147
|
+
# - Don't include in Compose and apply after if algorithm is N2V?
|
|
148
|
+
# - or just don't return patch? but then mask is in the target position
|
|
145
149
|
# TODO why return patch?
|
|
146
150
|
return masked, patch, mask
|
|
@@ -90,8 +90,11 @@ class Normalize(Transform):
|
|
|
90
90
|
self.eps = 1e-6
|
|
91
91
|
|
|
92
92
|
def __call__(
|
|
93
|
-
self,
|
|
94
|
-
|
|
93
|
+
self,
|
|
94
|
+
patch: np.ndarray,
|
|
95
|
+
target: Optional[NDArray] = None,
|
|
96
|
+
**additional_arrays: NDArray,
|
|
97
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
95
98
|
"""Apply the transform to the source patch and the target (optional).
|
|
96
99
|
|
|
97
100
|
Parameters
|
|
@@ -100,6 +103,9 @@ class Normalize(Transform):
|
|
|
100
103
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
101
104
|
target : NDArray, optional
|
|
102
105
|
Target for the patch, by default None.
|
|
106
|
+
**additional_arrays : NDArray
|
|
107
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
108
|
+
`target`.
|
|
103
109
|
|
|
104
110
|
Returns
|
|
105
111
|
-------
|
|
@@ -111,6 +117,11 @@ class Normalize(Transform):
|
|
|
111
117
|
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
112
118
|
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
113
119
|
)
|
|
120
|
+
if len(additional_arrays) != 0:
|
|
121
|
+
raise NotImplementedError(
|
|
122
|
+
"Transforming additional arrays is currently not supported for "
|
|
123
|
+
"`Normalize`."
|
|
124
|
+
)
|
|
114
125
|
|
|
115
126
|
# reshape mean and std and apply the normalization to the patch
|
|
116
127
|
means = _reshape_stats(self.image_means, patch.ndim)
|
|
@@ -129,7 +140,7 @@ class Normalize(Transform):
|
|
|
129
140
|
else:
|
|
130
141
|
norm_target = None
|
|
131
142
|
|
|
132
|
-
return norm_patch, norm_target
|
|
143
|
+
return norm_patch, norm_target, additional_arrays
|
|
133
144
|
|
|
134
145
|
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
135
146
|
"""
|
careamics/transforms/xy_flip.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""XY flip transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
@@ -78,8 +79,11 @@ class XYFlip(Transform):
|
|
|
78
79
|
self.rng = np.random.default_rng(seed=seed)
|
|
79
80
|
|
|
80
81
|
def __call__(
|
|
81
|
-
self,
|
|
82
|
-
|
|
82
|
+
self,
|
|
83
|
+
patch: NDArray,
|
|
84
|
+
target: Optional[NDArray] = None,
|
|
85
|
+
**additional_arrays: NDArray,
|
|
86
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
83
87
|
"""Apply the transform to the source patch and the target (optional).
|
|
84
88
|
|
|
85
89
|
Parameters
|
|
@@ -88,6 +92,9 @@ class XYFlip(Transform):
|
|
|
88
92
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
89
93
|
target : Optional[np.ndarray], optional
|
|
90
94
|
Target for the patch, by default None.
|
|
95
|
+
**additional_arrays : NDArray
|
|
96
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
97
|
+
`target`.
|
|
91
98
|
|
|
92
99
|
Returns
|
|
93
100
|
-------
|
|
@@ -95,17 +102,20 @@ class XYFlip(Transform):
|
|
|
95
102
|
Transformed patch and target.
|
|
96
103
|
"""
|
|
97
104
|
if self.rng.random() > self.p:
|
|
98
|
-
return patch, target
|
|
105
|
+
return patch, target, additional_arrays
|
|
99
106
|
|
|
100
107
|
# choose an axis to flip
|
|
101
108
|
axis = self.rng.choice(self.axis_indices)
|
|
102
109
|
|
|
103
110
|
patch_transformed = self._apply(patch, axis)
|
|
104
111
|
target_transformed = self._apply(target, axis) if target is not None else None
|
|
112
|
+
additional_transformed = {
|
|
113
|
+
key: self._apply(array, axis) for key, array in additional_arrays.items()
|
|
114
|
+
}
|
|
105
115
|
|
|
106
|
-
return patch_transformed, target_transformed
|
|
116
|
+
return patch_transformed, target_transformed, additional_transformed
|
|
107
117
|
|
|
108
|
-
def _apply(self, patch:
|
|
118
|
+
def _apply(self, patch: NDArray, axis: int) -> NDArray:
|
|
109
119
|
"""Apply the transform to the image.
|
|
110
120
|
|
|
111
121
|
Parameters
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
@@ -49,8 +50,11 @@ class XYRandomRotate90(Transform):
|
|
|
49
50
|
self.rng = np.random.default_rng(seed=seed)
|
|
50
51
|
|
|
51
52
|
def __call__(
|
|
52
|
-
self,
|
|
53
|
-
|
|
53
|
+
self,
|
|
54
|
+
patch: NDArray,
|
|
55
|
+
target: Optional[NDArray] = None,
|
|
56
|
+
**additional_arrays: NDArray,
|
|
57
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
54
58
|
"""Apply the transform to the source patch and the target (optional).
|
|
55
59
|
|
|
56
60
|
Parameters
|
|
@@ -59,6 +63,9 @@ class XYRandomRotate90(Transform):
|
|
|
59
63
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
60
64
|
target : Optional[np.ndarray], optional
|
|
61
65
|
Target for the patch, by default None.
|
|
66
|
+
**additional_arrays : NDArray
|
|
67
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
68
|
+
`target`.
|
|
62
69
|
|
|
63
70
|
Returns
|
|
64
71
|
-------
|
|
@@ -66,7 +73,7 @@ class XYRandomRotate90(Transform):
|
|
|
66
73
|
Transformed patch and target.
|
|
67
74
|
"""
|
|
68
75
|
if self.rng.random() > self.p:
|
|
69
|
-
return patch, target
|
|
76
|
+
return patch, target, additional_arrays
|
|
70
77
|
|
|
71
78
|
# number of rotations
|
|
72
79
|
n_rot = self.rng.integers(1, 4)
|
|
@@ -76,12 +83,14 @@ class XYRandomRotate90(Transform):
|
|
|
76
83
|
target_transformed = (
|
|
77
84
|
self._apply(target, n_rot, axes) if target is not None else None
|
|
78
85
|
)
|
|
86
|
+
additional_transformed = {
|
|
87
|
+
key: self._apply(array, n_rot, axes)
|
|
88
|
+
for key, array in additional_arrays.items()
|
|
89
|
+
}
|
|
79
90
|
|
|
80
|
-
return patch_transformed, target_transformed
|
|
91
|
+
return patch_transformed, target_transformed, additional_transformed
|
|
81
92
|
|
|
82
|
-
def _apply(
|
|
83
|
-
self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
|
|
84
|
-
) -> np.ndarray:
|
|
93
|
+
def _apply(self, patch: NDArray, n_rot: int, axes: Tuple[int, int]) -> NDArray:
|
|
85
94
|
"""Apply the transform to the image.
|
|
86
95
|
|
|
87
96
|
Parameters
|