careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Computing data statistics."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute the per-channel mean and standard deviation of an array.

    The channel axis is assumed to be axis 1, i.e. the expected input layout is
    (S, C, (Z), Y, X). Every other axis is reduced.

    Parameters
    ----------
    image : NDArray
        Input array with the channels on axis 1.

    Returns
    -------
    tuple of (NDArray, NDArray)
        Means and standard deviations, one value per channel.
    """
    channel_axis = 1
    all_axes = np.arange(image.ndim)
    # Reduce over everything but the channel axis
    reduced_axes = tuple(np.delete(all_axes, channel_axis))

    means = np.mean(image, axis=reduced_axes)
    stds = np.std(image, axis=reduced_axes)
    return means, stds
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update running per-channel statistics with a new batch of values.

    This is the batched form of Welford's online algorithm: for each channel,
    the running mean and the running sum of squared deviations are updated
    with all the values of that channel at once.

    `count`, `mean` and `m2` are modified in place and also returned.

    Parameters
    ----------
    count : NDArray
        Number of values seen so far, one entry per channel.
    mean : NDArray
        Running mean, one entry per channel.
    m2 : NDArray
        Running sum of squared deviations from the mean, one entry per channel
        (divide by `count` to obtain the variance).
    new_values : NDArray
        New values, with the channels along the first axis. All remaining axes
        are flattened.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and sum of squared deviations.
    """
    values = np.asarray(new_values)
    n_channels = values.shape[0]
    n_per_channel = int(np.prod(values.shape[1:], dtype=int))

    # One row of flattened values per channel
    flat = values.reshape(n_channels, n_per_channel)

    count += n_per_channel

    # new values - old mean (broadcast instead of materializing a list of means)
    delta = flat - mean[:, np.newaxis]

    mean += (delta / count[:, np.newaxis]).sum(axis=1)

    # new values - new mean
    delta2 = flat - mean[:, np.newaxis]

    m2 += (delta * delta2).sum(axis=1)

    return (count, mean, m2)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Finalize the iterative mean and standard deviation computation.

    The standard deviation is computed per channel as `sqrt(m2 / count)`. If any
    channel has seen fewer than two values, NaN arrays are returned for both the
    mean and the standard deviation.

    Parameters
    ----------
    count : NDArray
        Number of values seen, one entry per channel.
    mean : NDArray
        Running mean, one entry per channel.
    m2 : NDArray
        Running sum of squared deviations from the mean, one entry per channel.

    Returns
    -------
    tuple[NDArray, NDArray]
        Final mean and standard deviation.
    """
    count = np.asarray(count)
    m2 = np.asarray(m2)

    # Check the counts before dividing, so that empty channels do not trigger
    # divide-by-zero warnings on the way to returning NaNs.
    if np.any(count < 2):
        return np.full(np.shape(mean), np.nan), np.full(m2.shape, np.nan)

    return mean, np.sqrt(m2 / count)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class WelfordStatistics:
    """Accumulate per-channel mean and standard deviation over several samples.

    The statistics are updated iteratively with the Welford algorithm, based on
    the implementation described in:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    """

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Update the running statistics with a new sample.

        When `sample_idx` is 0, the running statistics are (re)initialized from
        `array`; otherwise the existing statistics are updated.

        Parameters
        ----------
        array : NDArray
            Input array, with the channels on axis 1.
        sample_idx : int
            Current sample number.
        """
        self.sample_idx = sample_idx
        n_channels = array.shape[1]

        # One sub-array per channel, stacked along a new leading axis
        sample_channels = np.array(np.split(array, n_channels, axis=1))

        if self.sample_idx == 0:
            # First sample: start from its mean, with zero-valued count and m2
            # arrays of shape (C,)
            self.mean, _ = compute_normalization_stats(array)
            running = (np.zeros(n_channels), self.mean, np.zeros(n_channels))
        else:
            # Subsequent samples: continue from the stored statistics
            running = (self.count, self.mean, self.m2)

        running_count, running_mean, running_m2 = running
        self.count, self.mean, self.m2 = update_iterative_stats(
            count=running_count,
            mean=running_mean,
            m2=running_m2,
            new_values=sample_channels,
        )

        self.sample_idx += 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Compute the final statistics from the accumulated values.

        Returns
        -------
        tuple of numpy arrays
            Final mean and standard deviation.
        """
        return finalize_iterative_stats(self.count, self.mean, self.m2)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# from multiprocessing import Value
|
|
147
|
+
# from typing import tuple
|
|
148
|
+
|
|
149
|
+
# import numpy as np
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# class RunningStats:
|
|
153
|
+
# """Calculates running mean and std."""
|
|
154
|
+
|
|
155
|
+
# def __init__(self) -> None:
|
|
156
|
+
# self.reset()
|
|
157
|
+
|
|
158
|
+
# def reset(self) -> None:
|
|
159
|
+
# """Reset the running stats."""
|
|
160
|
+
# self.avg_mean = Value("d", 0)
|
|
161
|
+
# self.avg_std = Value("d", 0)
|
|
162
|
+
# self.m2 = Value("d", 0)
|
|
163
|
+
# self.count = Value("i", 0)
|
|
164
|
+
|
|
165
|
+
# def init(self, mean: float, std: float) -> None:
|
|
166
|
+
# """Initialize running stats."""
|
|
167
|
+
# with self.avg_mean.get_lock():
|
|
168
|
+
# self.avg_mean.value += mean
|
|
169
|
+
# with self.avg_std.get_lock():
|
|
170
|
+
# self.avg_std.value = std
|
|
171
|
+
|
|
172
|
+
# def compute_std(self) -> tuple[float, float]:
|
|
173
|
+
# """Compute std."""
|
|
174
|
+
# if self.count.value >= 2:
|
|
175
|
+
# self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
|
|
176
|
+
|
|
177
|
+
# def update(self, value: float) -> None:
|
|
178
|
+
# """Update running stats."""
|
|
179
|
+
# with self.count.get_lock():
|
|
180
|
+
# self.count.value += 1
|
|
181
|
+
# delta = value - self.avg_mean.value
|
|
182
|
+
# with self.avg_mean.get_lock():
|
|
183
|
+
# self.avg_mean.value += delta / self.count.value
|
|
184
|
+
# delta2 = value - self.avg_mean.value
|
|
185
|
+
# with self.m2.get_lock():
|
|
186
|
+
# self.m2.value += delta * delta2
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""In-memory dataset module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Callable, Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from careamics.file_io.read import read_tiff
|
|
13
|
+
from careamics.transforms import Compose
|
|
14
|
+
|
|
15
|
+
from ..config import DataConfig
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
17
|
+
from ..utils.logging import get_logger
|
|
18
|
+
from .patching.patching import (
|
|
19
|
+
PatchedOutput,
|
|
20
|
+
Stats,
|
|
21
|
+
prepare_patches_supervised,
|
|
22
|
+
prepare_patches_supervised_array,
|
|
23
|
+
prepare_patches_unsupervised,
|
|
24
|
+
prepare_patches_unsupervised_array,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class InMemoryDataset(Dataset):
    """Dataset holding all patches in memory and serving them one at a time.

    Parameters
    ----------
    data_config : CAREamics DataConfig
        (see careamics.config.data_model.DataConfig)
        Data configuration.
    inputs : numpy.ndarray or list[pathlib.Path]
        Input data.
    input_target : numpy.ndarray or list[pathlib.Path], optional
        Target data, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        data_config: DataConfig,
        inputs: Union[np.ndarray, list[Path]],
        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """
        Constructor.

        Patches are extracted immediately. Statistics that are not already set
        in the configuration are taken from the patch extraction output, and
        the configuration is then updated with the statistics in use.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration.
        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.
        """
        self.data_config = data_config
        self.inputs = inputs
        self.input_targets = input_target
        self.axes = self.data_config.axes
        self.patch_size = self.data_config.patch_size
        self.read_source_func = read_source_func

        # the dataset is supervised as soon as targets are provided
        patches_data = self._prepare_patches(self.input_targets is not None)
        self.data = patches_data.patches
        self.data_targets = patches_data.targets

        # image statistics: configured values take precedence over computed ones
        if self.data_config.image_means is not None:
            self.image_stats = Stats(
                self.data_config.image_means, self.data_config.image_stds
            )
        else:
            self.image_stats = patches_data.image_stats
            logger.info(
                f"Computed dataset mean: {self.image_stats.means}, "
                f"std: {self.image_stats.stds}"
            )

        # target statistics: same precedence rule
        if self.data_config.target_means is not None:
            self.target_stats = Stats(
                self.data_config.target_means, self.data_config.target_stds
            )
        else:
            self.target_stats = patches_data.target_stats

        stats_kwargs = {
            "image_means": self.image_stats.means,
            "image_stds": self.image_stats.stds,
            "target_means": self.target_stats.means,
            "target_stds": self.target_stats.stds,
        }

        # record the statistics in the configuration; the object is mutable and
        # should then be recorded in the CAREamist obj
        self.data_config.set_means_and_stds(**stats_kwargs)

        # normalization always comes first, followed by the configured transforms
        self.patch_transform = Compose(
            transform_list=[NormalizeModel(**stats_kwargs)]
            + self.data_config.transforms,
        )

    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Extract patches from the data source.

        The extraction function is selected from the type of the inputs (array
        or list of paths) and from whether targets are used.

        Parameters
        ----------
        supervised : bool
            Whether the dataset is supervised or not.

        Returns
        -------
        PatchedOutput
            Output of the selected patch extraction function.

        Raises
        ------
        ValueError
            If `supervised` is True and the inputs and targets are not both
            numpy arrays or both lists.
        """
        if not supervised:
            if isinstance(self.inputs, np.ndarray):
                return prepare_patches_unsupervised_array(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                )
            return prepare_patches_unsupervised(
                self.inputs,
                self.axes,
                self.patch_size,
                self.read_source_func,
            )

        both_arrays = isinstance(self.inputs, np.ndarray) and isinstance(
            self.input_targets, np.ndarray
        )
        if both_arrays:
            return prepare_patches_supervised_array(
                self.inputs,
                self.axes,
                self.input_targets,
                self.patch_size,
            )

        both_lists = isinstance(self.inputs, list) and isinstance(
            self.input_targets, list
        )
        if both_lists:
            return prepare_patches_supervised(
                self.inputs,
                self.input_targets,
                self.axes,
                self.patch_size,
                self.read_source_func,
            )

        raise ValueError(
            f"Data and target must be of the same type, either both numpy "
            f"arrays or both lists of paths, got {type(self.inputs)} (data) "
            f"and {type(self.input_targets)} (target)."
        )

    def __len__(self) -> int:
        """
        Return the number of patches in the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        n_patches = self.data.shape[0]
        return n_patches

    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the transformed patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        tuple of numpy.ndarray
            Output of the patch transform.

        Raises
        ------
        ValueError
            If there is no target and the configuration has no N2V manipulation.
        """
        patch = self.data[index]

        if self.data_targets is None:
            # no target: only valid when N2V manipulation is configured
            if self.data_config.has_n2v_manipulate():  # TODO not compatible with HDN
                return self.patch_transform(patch=patch)

            raise ValueError(
                "Something went wrong! No target provided (not supervised training) "
                "and no N2V manipulation (no N2V training)."
            )

        target = self.data_targets[index]
        return self.patch_transform(patch=patch, target=target)

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return the statistics of the input data.

        Only the image statistics are returned, not those of the targets.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_patches: int = 1,
    ) -> InMemoryDataset:
        """Split a new dataset away from the current one.

        Randomly chosen patches (and their targets, if any) are removed from
        this dataset and moved into a deep copy of it. This method is used to
        extract random validation patches from the dataset.

        Parameters
        ----------
        percentage : float, optional
            Percentage of patches to extract, by default 0.1.
        minimum_patches : int, optional
            Minimum number of patches to extract, by default 1.

        Returns
        -------
        CAREamics InMemoryDataset
            New dataset with the extracted patches.

        Raises
        ------
        ValueError
            If `percentage` is not between 0 and 1.
        ValueError
            If `minimum_patches` is not between 1 and the number of patches.
        """
        if percentage > 1 or percentage < 0:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        total_patches = len(self)

        if minimum_patches > total_patches or minimum_patches < 1:
            raise ValueError(
                f"Minimum number of patches must be between 1 and "
                f"{total_patches} (number of patches), got "
                f"{minimum_patches}. Adjust the patch size or the minimum number of "
                f"patches."
            )

        # number of patches to extract (either percentage rounded or minimum number)
        n_patches = max(round(total_patches * percentage), minimum_patches)

        # draw the patches to move, without replacement
        indices = np.random.choice(total_patches, n_patches, replace=False)

        has_targets = self.data_targets is not None

        # move the selected patches out of this dataset
        val_patches = self.data[indices]
        self.data = np.delete(self.data, indices, axis=0)

        if has_targets:
            val_targets = self.data_targets[indices]
            self.data_targets = np.delete(self.data_targets, indices, axis=0)

        # clone the dataset, then swap in the extracted patches and targets
        dataset = copy.deepcopy(self)
        dataset.data = val_patches

        if has_targets:
            dataset.data_targets = val_targets

        return dataset
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""In-memory prediction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from careamics.transforms import Compose
|
|
9
|
+
|
|
10
|
+
from ..config import InferenceConfig
|
|
11
|
+
from ..config.transformations import NormalizeModel
|
|
12
|
+
from .dataset_utils import reshape_array
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class InMemoryPredDataset(Dataset):
    """Prediction dataset serving normalized elements of an in-memory array.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        The input array is reshaped according to the configured axes, and a
        normalization transform is built from the configured image statistics.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.
        """
        self.pred_config = prediction_config
        self.input_array = inputs

        # values read from the prediction configuration
        self.axes = self.pred_config.axes
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # reshaped data, indexed along its first axis by __getitem__
        self.data = reshape_array(self.input_array, self.axes)

        # the only transform applied at prediction time is the normalization
        normalization = NormalizeModel(
            image_means=self.image_means, image_stds=self.image_stds
        )
        self.patch_transform = Compose(
            transform_list=[normalization],
        )

    def __len__(self) -> int:
        """
        Return the number of elements in the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> NDArray:
        """
        Return the transformed element corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the element to return.

        Returns
        -------
        NDArray
            Transformed patch.
        """
        sample = self.data[index]

        # the transform returns a pair, only its first element is kept
        transformed_patch, _ = self.patch_transform(patch=sample)

        return transformed_patch
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""In-memory tiled prediction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from careamics.transforms import Compose
|
|
9
|
+
|
|
10
|
+
from ..config import InferenceConfig
|
|
11
|
+
from ..config.tile_information import TileInformation
|
|
12
|
+
from ..config.transformations import NormalizeModel
|
|
13
|
+
from .dataset_utils import reshape_array
|
|
14
|
+
from .tiling import extract_tiles
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InMemoryTiledPredDataset(Dataset):
    """Prediction dataset holding the tiles of an in-memory array.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Tiles are extracted immediately and kept in memory, and a normalization
        transform is built from the configured image statistics.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.

        Raises
        ------
        ValueError
            If the tile size or the tile overlap is None in the configuration.
        """
        tiling_undefined = (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        )
        if tiling_undefined:
            raise ValueError(
                "Tile size and overlap must be provided to use the tiled prediction "
                "dataset."
            )

        self.pred_config = prediction_config
        self.input_array = inputs

        # values read from the prediction configuration
        self.axes = self.pred_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # list of (tile, tile information) pairs
        self.data = self._prepare_tiles()

        # the only transform applied at prediction time is the normalization
        normalization = NormalizeModel(
            image_means=self.image_means, image_stds=self.image_stds
        )
        self.patch_transform = Compose(
            transform_list=[normalization],
        )

    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
        """
        Reshape the input array and extract all of its tiles.

        Returns
        -------
        list of tuples of NDArray and TileInformation
            List of tiles and tile information.

        Raises
        ------
        ValueError
            If no tile was generated.
        """
        reshaped_sample = reshape_array(self.input_array, self.axes)

        # the tiles are materialized into a list in one go
        tiles = list(
            extract_tiles(
                arr=reshaped_sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )
        )

        if not tiles:
            raise ValueError("No tiles generated, ")

        return tiles

    def __len__(self) -> int:
        """
        Return the number of tiles in the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
        """
        Return the transformed tile corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the tile to return.

        Returns
        -------
        tuple of NDArray and TileInformation
            Transformed tile and its tile information.
        """
        tile_array, tile_info = self.data[index]

        # the transform returns a pair, only its first element is kept
        transformed_tile, _ = self.patch_transform(patch=tile_array)

        return transformed_tile, tile_info
|