careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Computing data statistics."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute mean and standard deviation of an array.

    Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
    computed per channel, i.e. by reducing over every axis except axis 1.

    Parameters
    ----------
    image : NDArray
        Input array, with the channel dimension on axis 1.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        Arrays of shape (C,) holding, per channel, the mean and the population
        (``ddof=0``, NumPy's default) standard deviation.

    Raises
    ------
    ValueError
        If `image` has fewer than 2 dimensions (no channel axis).
    """
    if image.ndim < 2:
        raise ValueError(
            f"Expected an array with at least 2 dimensions (S, C, ...), got "
            f"{image.ndim} dimension(s)."
        )

    # Reduce over every axis except the channel axis (axis 1).
    axes = (0, *range(2, image.ndim))
    return np.mean(image, axis=axes), np.std(image, axis=axes)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update the per-channel count, mean and M2 with a new batch of values.

    This is the batched form of Welford's algorithm. For a channel receiving
    `n` new values `x`, the new count is `count + n`, and the new mean is
    `mean + sum(x - mean) / (count + n)`. M2 grows by
    `sum((x - old_mean) * (x - new_mean))`, which is exact for a whole batch.

    The input arrays are not modified; updated arrays are returned.

    Parameters
    ----------
    count : NDArray
        Number of values accumulated so far, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Running sum of squared deviations from the mean (Welford's M2), per
        channel. The population variance is `m2 / count`.
    new_values : NDArray
        New values indexed by channel along the first axis, e.g. an array of
        shape (C, S, (Z), Y, X) or a sequence of C arrays.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and M2.
    """
    # Fresh arrays so the caller's accumulators are left untouched.
    new_count = np.asarray(count) + np.array(
        [np.size(channel) for channel in new_values]
    )
    new_mean = np.array(mean, dtype=np.float64)
    new_m2 = np.array(m2, dtype=np.float64)

    for idx, channel in enumerate(new_values):
        values = np.ravel(channel).astype(np.float64, copy=False)
        if values.size == 0:
            # Nothing to add; also avoids 0 / 0 when the count is still zero.
            continue

        # new values - old mean
        delta = values - new_mean[idx]
        new_mean[idx] += delta.sum() / new_count[idx]
        # new values - new mean
        delta2 = values - new_mean[idx]
        new_m2[idx] += np.dot(delta, delta2)

    return new_count, new_mean, new_m2
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Finalize the mean and standard deviation computation.

    Parameters
    ----------
    count : NDArray
        Number of values accumulated, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Sum of squared deviations from the mean (Welford's M2), per channel.

    Returns
    -------
    tuple[NDArray, NDArray]
        Final mean and population (``ddof=0``) standard deviation per channel.
        If any channel has fewer than 2 values, both arrays are filled with NaN.
    """
    count = np.asarray(count)
    mean = np.asarray(mean)
    m2 = np.asarray(m2)

    # Check the counts first so that no division by a zero count is attempted.
    if np.any(count < 2):
        return np.full(mean.shape, np.nan), np.full(m2.shape, np.nan)

    std = np.sqrt(m2 / count)
    return mean, std
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class WelfordStatistics:
    """Compute per-channel Welford statistics iteratively.

    The Welford algorithm is used to compute the mean and variance of an array
    iteratively. Each call to `update` folds a whole array of shape
    (S, C, (Z), Y, X) into running per-channel counts, means and sums of squared
    deviations (M2). Based on the implementation from:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    """

    def __init__(self) -> None:
        """Constructor, starting with no accumulated statistics."""
        # index of the next expected sample
        self.sample_idx = 0
        # per-channel running statistics, populated by the first `update` call
        self.count = None
        self.mean = None
        self.m2 = None

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Update the Welford statistics.

        Passing `sample_idx == 0` (re)starts the accumulation from `array`.

        Parameters
        ----------
        array : NDArray
            Input array, with the channel dimension on axis 1.
        sample_idx : int
            Current sample number.
        """
        self.sample_idx = sample_idx
        # Channel-first view (C, S, (Z), Y, X); avoids copying the whole array.
        sample_channels = np.moveaxis(array, 1, 0)

        if self.sample_idx == 0 or self.count is None:
            # Seed the running mean with this array's own mean, and start the
            # count and M2 from zero-valued arrays of shape (C,).
            mean, _ = compute_normalization_stats(array)
            n_channels = array.shape[1]
            self.count, self.mean, self.m2 = update_iterative_stats(
                count=np.zeros(n_channels),
                mean=mean,
                m2=np.zeros(n_channels),
                new_values=sample_channels,
            )
        else:
            # Fold the new array into the running statistics
            self.count, self.mean, self.m2 = update_iterative_stats(
                count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
            )

        self.sample_idx += 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Finalize the Welford statistics.

        Returns
        -------
        tuple of numpy arrays
            Final mean and standard deviation per channel.

        Raises
        ------
        ValueError
            If `update` has never been called.
        """
        if self.count is None:
            raise ValueError(
                "No statistics to finalize, call `update` at least once first."
            )
        return finalize_iterative_stats(self.count, self.mean, self.m2)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# from multiprocessing import Value
|
|
147
|
+
# from typing import tuple
|
|
148
|
+
|
|
149
|
+
# import numpy as np
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# class RunningStats:
|
|
153
|
+
# """Calculates running mean and std."""
|
|
154
|
+
|
|
155
|
+
# def __init__(self) -> None:
|
|
156
|
+
# self.reset()
|
|
157
|
+
|
|
158
|
+
# def reset(self) -> None:
|
|
159
|
+
# """Reset the running stats."""
|
|
160
|
+
# self.avg_mean = Value("d", 0)
|
|
161
|
+
# self.avg_std = Value("d", 0)
|
|
162
|
+
# self.m2 = Value("d", 0)
|
|
163
|
+
# self.count = Value("i", 0)
|
|
164
|
+
|
|
165
|
+
# def init(self, mean: float, std: float) -> None:
|
|
166
|
+
# """Initialize running stats."""
|
|
167
|
+
# with self.avg_mean.get_lock():
|
|
168
|
+
# self.avg_mean.value += mean
|
|
169
|
+
# with self.avg_std.get_lock():
|
|
170
|
+
# self.avg_std.value = std
|
|
171
|
+
|
|
172
|
+
# def compute_std(self) -> tuple[float, float]:
|
|
173
|
+
# """Compute std."""
|
|
174
|
+
# if self.count.value >= 2:
|
|
175
|
+
# self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
|
|
176
|
+
|
|
177
|
+
# def update(self, value: float) -> None:
|
|
178
|
+
# """Update running stats."""
|
|
179
|
+
# with self.count.get_lock():
|
|
180
|
+
# self.count.value += 1
|
|
181
|
+
# delta = value - self.avg_mean.value
|
|
182
|
+
# with self.avg_mean.get_lock():
|
|
183
|
+
# self.avg_mean.value += delta / self.count.value
|
|
184
|
+
# delta2 = value - self.avg_mean.value
|
|
185
|
+
# with self.m2.get_lock():
|
|
186
|
+
# self.m2.value += delta * delta2
|
|
@@ -4,47 +4,73 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any, Callable,
|
|
7
|
+
from typing import Any, Callable, Optional, Union
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from torch.utils.data import Dataset
|
|
11
11
|
|
|
12
12
|
from careamics.transforms import Compose
|
|
13
13
|
|
|
14
|
-
from ..config import DataConfig
|
|
15
|
-
from ..config.
|
|
14
|
+
from ..config import DataConfig
|
|
15
|
+
from ..config.transformations import NormalizeModel
|
|
16
16
|
from ..utils.logging import get_logger
|
|
17
|
-
from .dataset_utils import read_tiff
|
|
17
|
+
from .dataset_utils import read_tiff
|
|
18
18
|
from .patching.patching import (
|
|
19
|
+
PatchedOutput,
|
|
19
20
|
prepare_patches_supervised,
|
|
20
21
|
prepare_patches_supervised_array,
|
|
21
22
|
prepare_patches_unsupervised,
|
|
22
23
|
prepare_patches_unsupervised_array,
|
|
23
24
|
)
|
|
24
|
-
from .patching.tiled_patching import extract_tiles
|
|
25
25
|
|
|
26
26
|
logger = get_logger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class InMemoryDataset(Dataset):
|
|
30
|
-
"""Dataset storing data in memory and allowing generating patches from it.
|
|
30
|
+
"""Dataset storing data in memory and allowing generating patches from it.
|
|
31
|
+
|
|
32
|
+
Parameters
|
|
33
|
+
----------
|
|
34
|
+
data_config : CAREamics DataConfig
|
|
35
|
+
(see careamics.config.data_model.DataConfig)
|
|
36
|
+
Data configuration.
|
|
37
|
+
inputs : numpy.ndarray or list[pathlib.Path]
|
|
38
|
+
Input data.
|
|
39
|
+
input_target : numpy.ndarray or list[pathlib.Path], optional
|
|
40
|
+
Target data, by default None.
|
|
41
|
+
read_source_func : Callable, optional
|
|
42
|
+
Read source function for custom types, by default read_tiff.
|
|
43
|
+
**kwargs : Any
|
|
44
|
+
Additional keyword arguments, unused.
|
|
45
|
+
"""
|
|
31
46
|
|
|
32
47
|
def __init__(
|
|
33
48
|
self,
|
|
34
49
|
data_config: DataConfig,
|
|
35
|
-
inputs: Union[np.ndarray,
|
|
36
|
-
|
|
50
|
+
inputs: Union[np.ndarray, list[Path]],
|
|
51
|
+
input_target: Optional[Union[np.ndarray, list[Path]]] = None,
|
|
37
52
|
read_source_func: Callable = read_tiff,
|
|
38
53
|
**kwargs: Any,
|
|
39
54
|
) -> None:
|
|
40
55
|
"""
|
|
41
56
|
Constructor.
|
|
42
57
|
|
|
43
|
-
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
data_config : DataConfig
|
|
61
|
+
Data configuration.
|
|
62
|
+
inputs : numpy.ndarray or list[pathlib.Path]
|
|
63
|
+
Input data.
|
|
64
|
+
input_target : numpy.ndarray or list[pathlib.Path], optional
|
|
65
|
+
Target data, by default None.
|
|
66
|
+
read_source_func : Callable, optional
|
|
67
|
+
Read source function for custom types, by default read_tiff.
|
|
68
|
+
**kwargs : Any
|
|
69
|
+
Additional keyword arguments, unused.
|
|
44
70
|
"""
|
|
45
71
|
self.data_config = data_config
|
|
46
72
|
self.inputs = inputs
|
|
47
|
-
self.
|
|
73
|
+
self.input_targets = input_target
|
|
48
74
|
self.axes = self.data_config.axes
|
|
49
75
|
self.patch_size = self.data_config.patch_size
|
|
50
76
|
|
|
@@ -52,30 +78,52 @@ class InMemoryDataset(Dataset):
|
|
|
52
78
|
self.read_source_func = read_source_func
|
|
53
79
|
|
|
54
80
|
# Generate patches
|
|
55
|
-
supervised = self.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
#
|
|
59
|
-
self.data
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
81
|
+
supervised = self.input_targets is not None
|
|
82
|
+
patches_data = self._prepare_patches(supervised)
|
|
83
|
+
|
|
84
|
+
# Unpack the dataclass
|
|
85
|
+
self.data = patches_data.patches
|
|
86
|
+
self.data_targets = patches_data.targets
|
|
87
|
+
|
|
88
|
+
if self.data_config.image_means is None:
|
|
89
|
+
self.image_means = patches_data.image_stats.means
|
|
90
|
+
self.image_stds = patches_data.image_stats.stds
|
|
91
|
+
logger.info(
|
|
92
|
+
f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
|
|
93
|
+
)
|
|
68
94
|
else:
|
|
69
|
-
self.
|
|
95
|
+
self.image_means = self.data_config.image_means
|
|
96
|
+
self.image_stds = self.data_config.image_stds
|
|
70
97
|
|
|
98
|
+
if self.data_config.target_means is None:
|
|
99
|
+
self.target_means = patches_data.target_stats.means
|
|
100
|
+
self.target_stds = patches_data.target_stats.stds
|
|
101
|
+
else:
|
|
102
|
+
self.target_means = self.data_config.target_means
|
|
103
|
+
self.target_stds = self.data_config.target_stds
|
|
104
|
+
|
|
105
|
+
# update mean and std in configuration
|
|
106
|
+
# the object is mutable and should then be recorded in the CAREamist obj
|
|
107
|
+
self.data_config.set_mean_and_std(
|
|
108
|
+
image_means=self.image_means,
|
|
109
|
+
image_stds=self.image_stds,
|
|
110
|
+
target_means=self.target_means,
|
|
111
|
+
target_stds=self.target_stds,
|
|
112
|
+
)
|
|
71
113
|
# get transforms
|
|
72
114
|
self.patch_transform = Compose(
|
|
73
|
-
transform_list=
|
|
115
|
+
transform_list=[
|
|
116
|
+
NormalizeModel(
|
|
117
|
+
image_means=self.image_means,
|
|
118
|
+
image_stds=self.image_stds,
|
|
119
|
+
target_means=self.target_means,
|
|
120
|
+
target_stds=self.target_stds,
|
|
121
|
+
)
|
|
122
|
+
]
|
|
123
|
+
+ self.data_config.transforms,
|
|
74
124
|
)
|
|
75
125
|
|
|
76
|
-
def _prepare_patches(
|
|
77
|
-
self, supervised: bool
|
|
78
|
-
) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
|
|
126
|
+
def _prepare_patches(self, supervised: bool) -> PatchedOutput:
|
|
79
127
|
"""
|
|
80
128
|
Iterate over data source and create an array of patches.
|
|
81
129
|
|
|
@@ -86,23 +134,23 @@ class InMemoryDataset(Dataset):
|
|
|
86
134
|
|
|
87
135
|
Returns
|
|
88
136
|
-------
|
|
89
|
-
|
|
137
|
+
numpy.ndarray
|
|
90
138
|
Array of patches.
|
|
91
139
|
"""
|
|
92
140
|
if supervised:
|
|
93
141
|
if isinstance(self.inputs, np.ndarray) and isinstance(
|
|
94
|
-
self.
|
|
142
|
+
self.input_targets, np.ndarray
|
|
95
143
|
):
|
|
96
144
|
return prepare_patches_supervised_array(
|
|
97
145
|
self.inputs,
|
|
98
146
|
self.axes,
|
|
99
|
-
self.
|
|
147
|
+
self.input_targets,
|
|
100
148
|
self.patch_size,
|
|
101
149
|
)
|
|
102
|
-
elif isinstance(self.inputs, list) and isinstance(self.
|
|
150
|
+
elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
|
|
103
151
|
return prepare_patches_supervised(
|
|
104
152
|
self.inputs,
|
|
105
|
-
self.
|
|
153
|
+
self.input_targets,
|
|
106
154
|
self.axes,
|
|
107
155
|
self.patch_size,
|
|
108
156
|
self.read_source_func,
|
|
@@ -111,7 +159,7 @@ class InMemoryDataset(Dataset):
|
|
|
111
159
|
raise ValueError(
|
|
112
160
|
f"Data and target must be of the same type, either both numpy "
|
|
113
161
|
f"arrays or both lists of paths, got {type(self.inputs)} (data) "
|
|
114
|
-
f"and {type(self.
|
|
162
|
+
f"and {type(self.input_targets)} (target)."
|
|
115
163
|
)
|
|
116
164
|
else:
|
|
117
165
|
if isinstance(self.inputs, np.ndarray):
|
|
@@ -137,9 +185,9 @@ class InMemoryDataset(Dataset):
|
|
|
137
185
|
int
|
|
138
186
|
Length of the dataset.
|
|
139
187
|
"""
|
|
140
|
-
return
|
|
188
|
+
return self.data.shape[0]
|
|
141
189
|
|
|
142
|
-
def __getitem__(self, index: int) ->
|
|
190
|
+
def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
|
|
143
191
|
"""
|
|
144
192
|
Return the patch corresponding to the provided index.
|
|
145
193
|
|
|
@@ -150,7 +198,7 @@ class InMemoryDataset(Dataset):
|
|
|
150
198
|
|
|
151
199
|
Returns
|
|
152
200
|
-------
|
|
153
|
-
|
|
201
|
+
tuple of numpy.ndarray
|
|
154
202
|
Patch.
|
|
155
203
|
|
|
156
204
|
Raises
|
|
@@ -161,13 +209,13 @@ class InMemoryDataset(Dataset):
|
|
|
161
209
|
patch = self.data[index]
|
|
162
210
|
|
|
163
211
|
# if there is a target
|
|
164
|
-
if self.
|
|
212
|
+
if self.data_targets is not None:
|
|
165
213
|
# get target
|
|
166
214
|
target = self.data_targets[index]
|
|
167
215
|
|
|
168
216
|
return self.patch_transform(patch=patch, target=target)
|
|
169
217
|
|
|
170
|
-
elif self.data_config.has_n2v_manipulate():
|
|
218
|
+
elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
|
|
171
219
|
return self.patch_transform(patch=patch)
|
|
172
220
|
else:
|
|
173
221
|
raise ValueError(
|
|
@@ -193,7 +241,7 @@ class InMemoryDataset(Dataset):
|
|
|
193
241
|
|
|
194
242
|
Returns
|
|
195
243
|
-------
|
|
196
|
-
InMemoryDataset
|
|
244
|
+
CAREamics InMemoryDataset
|
|
197
245
|
New dataset with the extracted patches.
|
|
198
246
|
|
|
199
247
|
Raises
|
|
@@ -244,117 +292,3 @@ class InMemoryDataset(Dataset):
|
|
|
244
292
|
dataset.data_targets = val_targets
|
|
245
293
|
|
|
246
294
|
return dataset
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
class InMemoryPredictionDataset(Dataset):
|
|
250
|
-
"""
|
|
251
|
-
Dataset storing data in memory and allowing generating patches from it.
|
|
252
|
-
|
|
253
|
-
# TODO
|
|
254
|
-
"""
|
|
255
|
-
|
|
256
|
-
def __init__(
|
|
257
|
-
self,
|
|
258
|
-
prediction_config: InferenceConfig,
|
|
259
|
-
inputs: np.ndarray,
|
|
260
|
-
data_target: Optional[np.ndarray] = None,
|
|
261
|
-
read_source_func: Optional[Callable] = read_tiff,
|
|
262
|
-
) -> None:
|
|
263
|
-
"""Constructor.
|
|
264
|
-
|
|
265
|
-
Parameters
|
|
266
|
-
----------
|
|
267
|
-
array : np.ndarray
|
|
268
|
-
Array containing the data.
|
|
269
|
-
axes : str
|
|
270
|
-
Description of axes in format STCZYX.
|
|
271
|
-
|
|
272
|
-
Raises
|
|
273
|
-
------
|
|
274
|
-
ValueError
|
|
275
|
-
If data_path is not a directory.
|
|
276
|
-
"""
|
|
277
|
-
self.pred_config = prediction_config
|
|
278
|
-
self.input_array = inputs
|
|
279
|
-
self.axes = self.pred_config.axes
|
|
280
|
-
self.tile_size = self.pred_config.tile_size
|
|
281
|
-
self.tile_overlap = self.pred_config.tile_overlap
|
|
282
|
-
self.mean = self.pred_config.mean
|
|
283
|
-
self.std = self.pred_config.std
|
|
284
|
-
self.data_target = data_target
|
|
285
|
-
|
|
286
|
-
# tiling only if both tile size and overlap are provided
|
|
287
|
-
self.tiling = self.tile_size is not None and self.tile_overlap is not None
|
|
288
|
-
|
|
289
|
-
# read function
|
|
290
|
-
self.read_source_func = read_source_func
|
|
291
|
-
|
|
292
|
-
# Generate patches
|
|
293
|
-
self.data = self._prepare_tiles()
|
|
294
|
-
self.mean, self.std = self.pred_config.mean, self.pred_config.std
|
|
295
|
-
|
|
296
|
-
# get transforms
|
|
297
|
-
self.patch_transform = Compose(
|
|
298
|
-
transform_list=self.pred_config.transforms,
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
|
|
302
|
-
"""
|
|
303
|
-
Iterate over data source and create an array of patches.
|
|
304
|
-
|
|
305
|
-
Returns
|
|
306
|
-
-------
|
|
307
|
-
List[XArrayTile]
|
|
308
|
-
List of tiles.
|
|
309
|
-
"""
|
|
310
|
-
# reshape array
|
|
311
|
-
reshaped_sample = reshape_array(self.input_array, self.axes)
|
|
312
|
-
|
|
313
|
-
if self.tiling:
|
|
314
|
-
# generate patches, which returns a generator
|
|
315
|
-
patch_generator = extract_tiles(
|
|
316
|
-
arr=reshaped_sample,
|
|
317
|
-
tile_size=self.tile_size,
|
|
318
|
-
overlaps=self.tile_overlap,
|
|
319
|
-
)
|
|
320
|
-
patches_list = list(patch_generator)
|
|
321
|
-
|
|
322
|
-
if len(patches_list) == 0:
|
|
323
|
-
raise ValueError("No tiles generated, ")
|
|
324
|
-
|
|
325
|
-
return patches_list
|
|
326
|
-
else:
|
|
327
|
-
array_shape = reshaped_sample.squeeze().shape
|
|
328
|
-
return [(reshaped_sample, TileInformation(array_shape=array_shape))]
|
|
329
|
-
|
|
330
|
-
def __len__(self) -> int:
|
|
331
|
-
"""
|
|
332
|
-
Return the length of the dataset.
|
|
333
|
-
|
|
334
|
-
Returns
|
|
335
|
-
-------
|
|
336
|
-
int
|
|
337
|
-
Length of the dataset.
|
|
338
|
-
"""
|
|
339
|
-
return len(self.data)
|
|
340
|
-
|
|
341
|
-
def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
|
|
342
|
-
"""
|
|
343
|
-
Return the patch corresponding to the provided index.
|
|
344
|
-
|
|
345
|
-
Parameters
|
|
346
|
-
----------
|
|
347
|
-
index : int
|
|
348
|
-
Index of the patch to return.
|
|
349
|
-
|
|
350
|
-
Returns
|
|
351
|
-
-------
|
|
352
|
-
Tuple[np.ndarray, TileInformation]
|
|
353
|
-
Transformed patch.
|
|
354
|
-
"""
|
|
355
|
-
tile_array, tile_info = self.data[index]
|
|
356
|
-
|
|
357
|
-
# Apply transforms
|
|
358
|
-
transformed_tile, _ = self.patch_transform(patch=tile_array)
|
|
359
|
-
|
|
360
|
-
return transformed_tile, tile_info
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""In-memory prediction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from careamics.transforms import Compose
|
|
9
|
+
|
|
10
|
+
from ..config import InferenceConfig
|
|
11
|
+
from ..config.transformations import NormalizeModel
|
|
12
|
+
from .dataset_utils import reshape_array
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class InMemoryPredDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    The input array is reshaped once at construction, using the axes from the
    prediction configuration. Each item is one entry along the first axis of
    the reshaped array, normalized with the configured image means and
    standard deviations.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration, providing `axes`, `image_means` and
            `image_stds`.
        inputs : NDArray
            Input data, with axes described by `prediction_config.axes`.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Reshape data; presumably to the (S, C, (Z), Y, X) layout used across
        # careamics, with samples on the first axis - TODO confirm against
        # `reshape_array`.
        self.data = reshape_array(self.input_array, self.axes)

        # Normalization is the only transform applied at prediction time.
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
            ],
        )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Number of entries along the first axis of the reshaped data.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> NDArray:
        """
        Return the normalized image at the provided index.

        Parameters
        ----------
        index : int
            Index along the first axis of the reshaped data.

        Returns
        -------
        NDArray
            Transformed (normalized) image.
        """
        # The transform returns a (patch, target) pair; there is no target here.
        transformed_patch, _ = self.patch_transform(patch=self.data[index])

        return transformed_patch
|