careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff compares publicly available package versions as they appear in their respective public registries and is provided for informational purposes only.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6

careamics/dataset/dataset_utils/dataset_utils.py
@@ -0,0 +1,101 @@
+"""Dataset utilities."""
+
+from typing import List, Tuple
+
+import numpy as np
+
+from careamics.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def _get_shape_order(
+    shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
+) -> Tuple[Tuple[int, ...], str, List[int]]:
+    """
+    Compute a new shape for the array based on the reference axes.
+
+    Parameters
+    ----------
+    shape_in : Tuple[int, ...]
+        Input shape.
+    axes_in : str
+        Input axes.
+    ref_axes : str
+        Reference axes.
+
+    Returns
+    -------
+    Tuple[Tuple[int, ...], str, List[int]]
+        New shape, new axes, indices of axes in the new axes order.
+    """
+    indices = [axes_in.find(k) for k in ref_axes]
+
+    # remove all non-existing axes (index == -1)
+    new_indices = list(filter(lambda k: k != -1, indices))
+
+    # find axes order and get new shape
+    new_axes = [axes_in[ind] for ind in new_indices]
+    new_shape = tuple([shape_in[ind] for ind in new_indices])
+
+    return new_shape, "".join(new_axes), new_indices
+
+
+def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
+    """Reshape the data to (S, C, (Z), Y, X) by moving axes.
+
+    If the data has both S and T axes, the two axes will be merged. A singleton
+    dimension is added if there is no C axis.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input array.
+    axes : str
+        Description of axes in format `STCZYX`.
+
+    Returns
+    -------
+    np.ndarray
+        Reshaped array with shape (S, C, (Z), Y, X).
+    """
+    _x = x
+    _axes = axes
+
+    # sanity checks
+    if len(_axes) != len(_x.shape):
+        raise ValueError(
+            f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes "
+            f"correct?"
+        )
+
+    # get new x shape
+    new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes)
+
+    # if S is not in the list of axes, then add a singleton S
+    if "S" not in new_axes:
+        new_axes = "S" + new_axes
+        _x = _x[np.newaxis, ...]
+        new_x_shape = (1,) + new_x_shape
+
+        # need to change the array of indices
+        indices = [0] + [1 + i for i in indices]
+
+    # reshape by moving axes
+    destination = list(range(len(indices)))
+    _x = np.moveaxis(_x, indices, destination)
+
+    # remove T if necessary
+    if "T" in new_axes:
+        new_x_shape = (-1,) + new_x_shape[2:]  # remove T and S
+        new_axes = new_axes.replace("T", "")
+
+        # reshape S and T together
+        _x = _x.reshape(new_x_shape)
+
+    # add channel
+    if "C" not in new_axes:
+        # Add channel axis after S
+        _x = np.expand_dims(_x, new_axes.index("S") + 1)
+
+    return _x
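
For orientation, here is a minimal usage sketch of the new `reshape_array` (the example itself is not part of the diff; the import path assumes the module location shown above):

```python
import numpy as np

from careamics.dataset.dataset_utils.dataset_utils import reshape_array

# A time series of 2D frames: T=10, Y=64, X=64.
x = np.zeros((10, 64, 64))

# T is folded into the sample axis S and a singleton channel axis is
# inserted, yielding the canonical (S, C, Y, X) layout.
y = reshape_array(x, "TYX")
print(y.shape)  # (10, 1, 64, 64)
```
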
careamics/dataset/dataset_utils/file_utils.py
@@ -0,0 +1,141 @@
+"""File utilities."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+from typing import List, Union
+
+import numpy as np
+
+from careamics.config.support import SupportedData
+from careamics.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def get_files_size(files: List[Path]) -> float:
+    """Get files size in MB.
+
+    Parameters
+    ----------
+    files : List[Path]
+        List of files.
+
+    Returns
+    -------
+    float
+        Total size of the files in MB.
+    """
+    return np.sum([f.stat().st_size / 1024**2 for f in files])
+
+
+def list_files(
+    data_path: Union[str, Path],
+    data_type: Union[str, SupportedData],
+    extension_filter: str = "",
+) -> List[Path]:
+    """List recursively files in `data_path` and return a sorted list.
+
+    If `data_path` is a file, its name is validated against the `data_type` using
+    `fnmatch`, and the method returns `data_path` itself.
+
+    By default, if `data_type` is equal to `custom`, all files will be listed. To
+    further filter the files, use `extension_filter`.
+
+    `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
+    or "*.czi".
+
+    Parameters
+    ----------
+    data_path : Union[str, Path]
+        Path to the folder containing the data.
+    data_type : Union[str, SupportedData]
+        One of the supported data types (e.g. tif, custom).
+    extension_filter : str, optional
+        Extension filter, by default "".
+
+    Returns
+    -------
+    List[Path]
+        List of pathlib.Path objects.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the data path does not exist.
+    ValueError
+        If the data path is empty or no files with the extension were found.
+    ValueError
+        If the file does not match the requested extension.
+    """
+    # convert to Path
+    data_path = Path(data_path)
+
+    # raise error if the path does not exist
+    if not data_path.exists():
+        raise FileNotFoundError(f"Data path {data_path} does not exist.")
+
+    # get extension compatible with fnmatch and rglob search
+    extension = SupportedData.get_extension_pattern(data_type)
+
+    if data_type == SupportedData.CUSTOM and extension_filter != "":
+        extension = extension_filter
+
+    # search recursively
+    if data_path.is_dir():
+        # search recursively the path for files with the extension
+        files = sorted(data_path.rglob(extension))
+    else:
+        # raise error if it has the wrong extension
+        if not fnmatch(str(data_path.absolute()), extension):
+            raise ValueError(
+                f"File {data_path} does not match the requested extension "
+                f'"{extension}".'
+            )
+
+        # save in list
+        files = [data_path]
+
+    # raise error if no files were found
+    if len(files) == 0:
+        raise ValueError(
+            f'Data path {data_path} is empty or files with extension "{extension}" '
+            f"were not found."
+        )
+
+    return files
+
+
+def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
+    """
+    Validate source and target path lists.
+
+    The two lists should have the same number of files, and the filenames should match.
+
+    Parameters
+    ----------
+    src_files : List[Path]
+        List of source files.
+    tar_files : List[Path]
+        List of target files.
+
+    Raises
+    ------
+    ValueError
+        If the number of files in source and target folders is not the same.
+    ValueError
+        If some filenames in source and target folders are not the same.
+    """
+    # check equal length
+    if len(src_files) != len(tar_files):
+        raise ValueError(
+            f"The number of source files ({len(src_files)}) is not equal to the number "
+            f"of target files ({len(tar_files)})."
+        )
+
+    # check identical names
+    src_names = {f.name for f in src_files}
+    tar_names = {f.name for f in tar_files}
+    difference = src_names.symmetric_difference(tar_names)
+
+    if len(difference) > 0:
+        raise ValueError(f"Source and target files have different names: {difference}.")
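
A hypothetical usage sketch of `list_files` and `validate_source_target_files` follows (the folder names are illustrative, not from the diff):

```python
from careamics.config.support import SupportedData
from careamics.dataset.dataset_utils.file_utils import (
    list_files,
    validate_source_target_files,
)

# With the "custom" data type, extension_filter narrows the recursive search.
src = list_files("data/train", SupportedData.CUSTOM, extension_filter="*.npy")
tar = list_files("data/target", SupportedData.CUSTOM, extension_filter="*.npy")

# Raises ValueError if the counts differ or the filenames do not match.
validate_source_target_files(src, tar)
```
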
careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, Optional, Union
+
+from numpy.typing import NDArray
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.file_io.read import read_tiff
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: list[Path],
+    target_files: Optional[list[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : CAREamics DataConfig or InferenceConfig
+        Configuration.
+    data_files : list of pathlib.Path
+        List of data files.
+    target_files : list of pathlib.Path, optional
+        List of target files, by default None.
+    read_source_func : Callable, optional
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    tuple of NDArray and optional NDArray
+        Reshaped image and, if available, its reshaped target.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object.
+    # Each copy is configured independently to avoid having duplicate data returned
+    # from the workers.
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
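
Only `data_config.axes` is consumed in the snippet above, so a sketch can get away with a stand-in config object; in real use a CAREamics `DataConfig` or `InferenceConfig` is expected (the folder name is illustrative):

```python
from pathlib import Path
from types import SimpleNamespace

from careamics.dataset.dataset_utils.iterate_over_files import iterate_over_files

# Stand-in for DataConfig: only the `axes` attribute is read in this sketch.
config = SimpleNamespace(axes="YX")
files = sorted(Path("data/train").glob("*.tiff"))

for sample, target in iterate_over_files(config, files):
    # Each sample has been reshaped to (S, C, Y, X); target is None because
    # no target_files were passed.
    print(sample.shape, target)
```
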
careamics/dataset/dataset_utils/running_stats.py
@@ -0,0 +1,186 @@
+"""Computing data statistics."""
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
+    """
+    Compute mean and standard deviation of an array.
+
+    Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
+    computed per channel.
+
+    Parameters
+    ----------
+    image : NDArray
+        Input array.
+
+    Returns
+    -------
+    tuple of (list of floats, list of floats)
+        Lists of mean and standard deviation values per channel.
+    """
+    # Define the list of axes excluding the channel axis
+    axes = tuple(np.delete(np.arange(image.ndim), 1))
+    return np.mean(image, axis=axes), np.std(image, axis=axes)
+
+
+def update_iterative_stats(
+    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
+) -> tuple[NDArray, NDArray, NDArray]:
+    """Update the mean and variance of an array iteratively.
+
+    Parameters
+    ----------
+    count : NDArray
+        Number of elements in the array.
+    mean : NDArray
+        Mean of the array.
+    m2 : NDArray
+        Variance of the array.
+    new_values : NDArray
+        New values to add to the mean and variance.
+
+    Returns
+    -------
+    tuple[NDArray, NDArray, NDArray]
+        Updated count, mean, and variance.
+    """
+    count += np.array([np.prod(channel.shape) for channel in new_values])
+    # new values - old mean
+    delta = [
+        np.subtract(v.flatten(), [m] * len(v.flatten()))
+        for v, m in zip(new_values, mean)
+    ]
+
+    mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
+    # new values - new mean
+    delta2 = [
+        np.subtract(v.flatten(), [m] * len(v.flatten()))
+        for v, m in zip(new_values, mean)
+    ]
+
+    m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
+
+    return (count, mean, m2)
+
+
+def finalize_iterative_stats(
+    count: NDArray, mean: NDArray, m2: NDArray
+) -> tuple[NDArray, NDArray]:
+    """Finalize the mean and variance computation.
+
+    Parameters
+    ----------
+    count : NDArray
+        Number of elements in the array.
+    mean : NDArray
+        Mean of the array.
+    m2 : NDArray
+        Variance of the array.
+
+    Returns
+    -------
+    tuple[NDArray, NDArray]
+        Final mean and standard deviation.
+    """
+    std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
+    if any(c < 2 for c in count):
+        return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
+    else:
+        return mean, std
+
+
+class WelfordStatistics:
+    """Compute Welford statistics iteratively.
+
+    The Welford algorithm is used to compute the mean and variance of an array
+    iteratively. Based on the implementation from:
+    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    """
+
+    def update(self, array: NDArray, sample_idx: int) -> None:
+        """Update the Welford statistics.
+
+        Parameters
+        ----------
+        array : NDArray
+            Input array.
+        sample_idx : int
+            Current sample number.
+        """
+        self.sample_idx = sample_idx
+        sample_channels = np.array(np.split(array, array.shape[1], axis=1))
+
+        # Initialize the statistics
+        if self.sample_idx == 0:
+            # Compute the mean and standard deviation
+            self.mean, _ = compute_normalization_stats(array)
+            # Initialize the count and m2 with zero-valued arrays of shape (C,)
+            self.count, self.mean, self.m2 = update_iterative_stats(
+                count=np.zeros(array.shape[1]),
+                mean=self.mean,
+                m2=np.zeros(array.shape[1]),
+                new_values=sample_channels,
+            )
+        else:
+            # Update the statistics
+            self.count, self.mean, self.m2 = update_iterative_stats(
+                count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
+            )
+
+        self.sample_idx += 1
+
+    def finalize(self) -> tuple[NDArray, NDArray]:
+        """Finalize the Welford statistics.
+
+        Returns
+        -------
+        tuple of numpy arrays
+            Final mean and standard deviation.
+        """
+        return finalize_iterative_stats(self.count, self.mean, self.m2)
+
+
+# from multiprocessing import Value
+# from typing import tuple
+
+# import numpy as np
+
+
+# class RunningStats:
+#     """Calculates running mean and std."""
+
+#     def __init__(self) -> None:
+#         self.reset()
+
+#     def reset(self) -> None:
+#         """Reset the running stats."""
+#         self.avg_mean = Value("d", 0)
+#         self.avg_std = Value("d", 0)
+#         self.m2 = Value("d", 0)
+#         self.count = Value("i", 0)
+
+#     def init(self, mean: float, std: float) -> None:
+#         """Initialize running stats."""
+#         with self.avg_mean.get_lock():
+#             self.avg_mean.value += mean
+#         with self.avg_std.get_lock():
+#             self.avg_std.value = std
+
+#     def compute_std(self) -> tuple[float, float]:
+#         """Compute std."""
+#         if self.count.value >= 2:
+#             self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
+
+#     def update(self, value: float) -> None:
+#         """Update running stats."""
+#         with self.count.get_lock():
+#             self.count.value += 1
+#             delta = value - self.avg_mean.value
+#             with self.avg_mean.get_lock():
+#                 self.avg_mean.value += delta / self.count.value
+#             delta2 = value - self.avg_mean.value
+#             with self.m2.get_lock():
+#                 self.m2.value += delta * delta2
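
The update step implements Welford's online recurrences, mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n and M2_n = M2_{n-1} + (x_n - mean_{n-1})(x_n - mean_n), applied per channel across all pixels of each incoming sample. A hypothetical sketch with synthetic data (not part of the diff):

```python
import numpy as np

from careamics.dataset.dataset_utils.running_stats import WelfordStatistics

stats = WelfordStatistics()
rng = np.random.default_rng(seed=0)

# Feed (S, C, Y, X) samples one at a time, e.g. while streaming a dataset.
for idx in range(5):
    sample = rng.normal(loc=3.0, scale=2.0, size=(1, 2, 32, 32))
    stats.update(sample, sample_idx=idx)

mean, std = stats.finalize()  # per-channel arrays of shape (C,)
print(mean, std)
```
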