careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,9 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, field_validator
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
class TileInformation(BaseModel):
|
|
@@ -13,30 +11,43 @@ class TileInformation(BaseModel):
|
|
|
13
11
|
|
|
14
12
|
This model is used to represent the information required to stitch back a tile into
|
|
15
13
|
a larger image. It is used throughout the prediction pipeline of CAREamics.
|
|
14
|
+
|
|
15
|
+
Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
|
|
16
|
+
contain singleton dimensions.
|
|
16
17
|
"""
|
|
17
18
|
|
|
18
19
|
model_config = ConfigDict(validate_default=True)
|
|
19
20
|
|
|
20
|
-
array_shape:
|
|
21
|
-
|
|
21
|
+
array_shape: tuple[int, ...]
|
|
22
|
+
"""Shape of the original (untiled) array."""
|
|
23
|
+
|
|
22
24
|
last_tile: bool = False
|
|
23
|
-
|
|
24
|
-
|
|
25
|
+
"""Whether this tile is the last one of the array."""
|
|
26
|
+
|
|
27
|
+
overlap_crop_coords: tuple[tuple[int, ...], ...]
|
|
28
|
+
"""Inner coordinates of the tile where to crop the prediction in order to stitch
|
|
29
|
+
it back into the original image."""
|
|
30
|
+
|
|
31
|
+
stitch_coords: tuple[tuple[int, ...], ...]
|
|
32
|
+
"""Coordinates in the original image where to stitch the cropped tile back."""
|
|
33
|
+
|
|
34
|
+
sample_id: int
|
|
35
|
+
"""Sample ID of the tile."""
|
|
25
36
|
|
|
26
37
|
@field_validator("array_shape")
|
|
27
38
|
@classmethod
|
|
28
|
-
def no_singleton_dimensions(cls, v:
|
|
39
|
+
def no_singleton_dimensions(cls, v: tuple[int, ...]):
|
|
29
40
|
"""
|
|
30
41
|
Check that the array shape does not have any singleton dimensions.
|
|
31
42
|
|
|
32
43
|
Parameters
|
|
33
44
|
----------
|
|
34
|
-
v :
|
|
45
|
+
v : tuple of int
|
|
35
46
|
Array shape to check.
|
|
36
47
|
|
|
37
48
|
Returns
|
|
38
49
|
-------
|
|
39
|
-
|
|
50
|
+
tuple of int
|
|
40
51
|
The array shape if it does not contain singleton dimensions.
|
|
41
52
|
|
|
42
53
|
Raises
|
|
@@ -48,59 +59,26 @@ class TileInformation(BaseModel):
|
|
|
48
59
|
raise ValueError("Array shape must not contain singleton dimensions.")
|
|
49
60
|
return v
|
|
50
61
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def only_if_tiled(cls, v: bool, values: ValidationInfo):
|
|
54
|
-
"""
|
|
55
|
-
Check that the last tile flag is only set if tiling is enabled.
|
|
62
|
+
def __eq__(self, other_tile: object):
|
|
63
|
+
"""Check if two tile information objects are equal.
|
|
56
64
|
|
|
57
65
|
Parameters
|
|
58
66
|
----------
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
values : ValidationInfo
|
|
62
|
-
Validation information.
|
|
67
|
+
other_tile : object
|
|
68
|
+
Tile information object to compare with.
|
|
63
69
|
|
|
64
70
|
Returns
|
|
65
71
|
-------
|
|
66
72
|
bool
|
|
67
|
-
|
|
68
|
-
"""
|
|
69
|
-
if not values.data["tiled"]:
|
|
70
|
-
return False
|
|
71
|
-
return v
|
|
72
|
-
|
|
73
|
-
@field_validator("overlap_crop_coords", "stitch_coords")
|
|
74
|
-
@classmethod
|
|
75
|
-
def mandatory_if_tiled(
|
|
76
|
-
cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
|
|
77
|
-
) -> Optional[Tuple[int, ...]]:
|
|
73
|
+
Whether the two tile information objects are equal.
|
|
78
74
|
"""
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
Returns
|
|
91
|
-
-------
|
|
92
|
-
Optional[Tuple[int, ...]]
|
|
93
|
-
The coordinates if tiling is enabled, otherwise `None`.
|
|
94
|
-
|
|
95
|
-
Raises
|
|
96
|
-
------
|
|
97
|
-
ValueError
|
|
98
|
-
If the coordinates are `None` and tiling is enabled.
|
|
99
|
-
"""
|
|
100
|
-
if values.data["tiled"]:
|
|
101
|
-
if v is None:
|
|
102
|
-
raise ValueError("Value must be specified if tiling is enabled.")
|
|
103
|
-
|
|
104
|
-
return v
|
|
105
|
-
else:
|
|
106
|
-
return None
|
|
75
|
+
if not isinstance(other_tile, TileInformation):
|
|
76
|
+
return NotImplemented
|
|
77
|
+
|
|
78
|
+
return (
|
|
79
|
+
self.array_shape == other_tile.array_shape
|
|
80
|
+
and self.last_tile == other_tile.last_tile
|
|
81
|
+
and self.overlap_crop_coords == other_tile.overlap_crop_coords
|
|
82
|
+
and self.stitch_coords == other_tile.stitch_coords
|
|
83
|
+
and self.sample_id == other_tile.sample_id
|
|
84
|
+
)
|
|
@@ -35,15 +35,19 @@ class TrainingConfig(BaseModel):
|
|
|
35
35
|
)
|
|
36
36
|
|
|
37
37
|
num_epochs: int = Field(default=20, ge=1)
|
|
38
|
+
"""Number of epochs, greater than 0."""
|
|
38
39
|
|
|
39
40
|
logger: Optional[Literal["wandb", "tensorboard"]] = None
|
|
41
|
+
"""Logger to use during training. If None, no logger will be used. Available
|
|
42
|
+
loggers are defined in SupportedLogger."""
|
|
40
43
|
|
|
41
44
|
checkpoint_callback: CheckpointModel = CheckpointModel()
|
|
45
|
+
"""Checkpoint callback configuration."""
|
|
42
46
|
|
|
43
47
|
early_stopping_callback: Optional[EarlyStoppingModel] = Field(
|
|
44
48
|
default=None, validate_default=True
|
|
45
49
|
)
|
|
46
|
-
|
|
50
|
+
"""Early stopping callback configuration."""
|
|
47
51
|
|
|
48
52
|
def __str__(self) -> str:
|
|
49
53
|
"""Pretty string reprensenting the configuration.
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""Pydantic model for the Normalize transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal, Optional
|
|
4
4
|
|
|
5
|
-
from pydantic import ConfigDict, Field
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
6
7
|
|
|
7
8
|
from .transform_model import TransformModel
|
|
8
9
|
|
|
@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
|
|
|
28
29
|
)
|
|
29
30
|
|
|
30
31
|
name: Literal["Normalize"] = "Normalize"
|
|
31
|
-
|
|
32
|
-
|
|
32
|
+
image_means: list = Field(..., min_length=0, max_length=32)
|
|
33
|
+
image_stds: list = Field(..., min_length=0, max_length=32)
|
|
34
|
+
target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
|
|
35
|
+
target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
|
|
36
|
+
|
|
37
|
+
@model_validator(mode="after")
def validate_means_stds(self: Self) -> Self:
    """Validate that the means and stds have consistent lengths.

    Checks that `image_means` and `image_stds` have the same length, that
    `target_means` and `target_stds` are either both provided or both None,
    and, when both are provided, that they have the same length.

    Returns
    -------
    Self
        The instance of the model.

    Raises
    ------
    ValueError
        If any of the conditions above is violated.
    """
    if len(self.image_means) != len(self.image_stds):
        raise ValueError("The number of image means and stds must be the same.")

    if (self.target_means is None) != (self.target_stds is None):
        raise ValueError(
            "Both target means and stds must be provided together, or both None."
        )

    # After the check above, either both target lists are None or neither is;
    # both conditions are kept so static type checkers narrow both fields.
    if self.target_means is not None and self.target_stds is not None:
        if len(self.target_means) != len(self.target_stds):
            raise ValueError(
                "The number of target means and stds must be the same."
            )

    return self
|
|
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
|
|
|
72
72
|
If the value is not a power of 2.
|
|
73
73
|
"""
|
|
74
74
|
if value < 8:
|
|
75
|
-
raise ValueError(f"Value must be
|
|
75
|
+
raise ValueError(f"Value must be greater than 8 (got {value}).")
|
|
76
76
|
|
|
77
77
|
if (value & (value - 1)) != 0:
|
|
78
78
|
raise ValueError(f"Value must be a power of 2 (got {value}).")
|
careamics/dataset/__init__.py
CHANGED
|
@@ -1,6 +1,17 @@
|
|
|
1
1
|
"""Dataset module."""
|
|
2
2
|
|
|
3
|
-
__all__ = [
|
|
3
|
+
__all__ = [
|
|
4
|
+
"InMemoryDataset",
|
|
5
|
+
"InMemoryPredDataset",
|
|
6
|
+
"InMemoryTiledPredDataset",
|
|
7
|
+
"PathIterableDataset",
|
|
8
|
+
"IterableTiledPredDataset",
|
|
9
|
+
"IterablePredDataset",
|
|
10
|
+
]
|
|
4
11
|
|
|
5
12
|
from .in_memory_dataset import InMemoryDataset
|
|
13
|
+
from .in_memory_pred_dataset import InMemoryPredDataset
|
|
14
|
+
from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
|
|
6
15
|
from .iterable_dataset import PathIterableDataset
|
|
16
|
+
from .iterable_pred_dataset import IterablePredDataset
|
|
17
|
+
from .iterable_tiled_pred_dataset import IterableTiledPredDataset
|
|
@@ -2,17 +2,18 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"reshape_array",
|
|
5
|
+
"compute_normalization_stats",
|
|
5
6
|
"get_files_size",
|
|
6
7
|
"list_files",
|
|
7
8
|
"validate_source_target_files",
|
|
8
|
-
"
|
|
9
|
-
"
|
|
10
|
-
"read_zarr",
|
|
9
|
+
"iterate_over_files",
|
|
10
|
+
"WelfordStatistics",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
from .dataset_utils import
|
|
14
|
+
from .dataset_utils import (
|
|
15
|
+
reshape_array,
|
|
16
|
+
)
|
|
15
17
|
from .file_utils import get_files_size, list_files, validate_source_target_files
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
18
|
-
from .read_zarr import read_zarr
|
|
18
|
+
from .iterate_over_files import iterate_over_files
|
|
19
|
+
from .running_stats import WelfordStatistics, compute_normalization_stats
|
|
@@ -33,7 +33,7 @@ def list_files(
|
|
|
33
33
|
data_type: Union[str, SupportedData],
|
|
34
34
|
extension_filter: str = "",
|
|
35
35
|
) -> List[Path]:
|
|
36
|
-
"""
|
|
36
|
+
"""List recursively files in `data_path` and return a sorted list.
|
|
37
37
|
|
|
38
38
|
If `data_path` is a file, its name is validated against the `data_type` using
|
|
39
39
|
`fnmatch`, and the method returns `data_path` itself.
|
|
@@ -75,7 +75,7 @@ def list_files(
|
|
|
75
75
|
raise FileNotFoundError(f"Data path {data_path} does not exist.")
|
|
76
76
|
|
|
77
77
|
# get extension compatible with fnmatch and rglob search
|
|
78
|
-
extension = SupportedData.
|
|
78
|
+
extension = SupportedData.get_extension_pattern(data_type)
|
|
79
79
|
|
|
80
80
|
if data_type == SupportedData.CUSTOM and extension_filter != "":
|
|
81
81
|
extension = extension_filter
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Function to iterate over files."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Callable, Generator, Optional, Union
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import get_worker_info
|
|
10
|
+
|
|
11
|
+
from careamics.config import DataConfig, InferenceConfig
|
|
12
|
+
from careamics.file_io.read import read_tiff
|
|
13
|
+
from careamics.utils.logging import get_logger
|
|
14
|
+
|
|
15
|
+
from .dataset_utils import reshape_array
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: list[Path],
    target_files: Optional[list[Path]] = None,
    read_source_func: Callable = read_tiff,
) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
    """Iterate over data source and yield whole reshaped images.

    When called inside a PyTorch DataLoader worker, files are split round-robin
    between workers (file `i` is handled by worker `i % num_workers`) so that
    no file is yielded by more than one worker. Outside a worker, every file is
    processed.

    Files that cannot be read or reshaped are logged and skipped.

    Parameters
    ----------
    data_config : CAREamics DataConfig or InferenceConfig
        Configuration; only its `axes` attribute is used here.
    data_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path, optional
        List of target files, by default None. If given, it must have the same
        length as `data_files`, and the i-th target must have the same file
        name as the i-th data file.
    read_source_func : Callable, optional
        Function to read the source, by default read_tiff. It is called as
        `read_source_func(path, data_config.axes)`.

    Yields
    ------
    tuple of (NDArray, NDArray or None)
        Reshaped image and, if `target_files` is given, the reshaped target;
        otherwise None as second element.

    Raises
    ------
    ValueError
        If `target_files` and `data_files` have different lengths, or if a
        data file and its target do not have the same file name.
    """
    # Validate pairing up front: a length mismatch previously surfaced as an
    # IndexError that was swallowed by the read-error handler.
    if target_files is not None and len(target_files) != len(data_files):
        raise ValueError(
            f"Number of target files ({len(target_files)}) does not match the "
            f"number of data files ({len(data_files)})."
        )

    # When num_workers > 0, each worker process will have a different copy of the
    # dataset object
    # Configuring each copy independently to avoid having duplicate data returned
    # from the workers
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # iterate over the files
    for i, filename in enumerate(data_files):
        # only handle the files assigned to this worker
        if i % num_workers != worker_id:
            continue

        # Mismatched source/target pairs are a user error: raise outside the
        # try block so the error is not swallowed by the read-error handler.
        if target_files is not None and filename.name != target_files[i].name:
            raise ValueError(
                f"File {filename} does not match target file "
                f"{target_files[i]}. Have you passed sorted "
                f"arrays?"
            )

        try:
            # read and reshape data
            sample = read_source_func(filename, data_config.axes)
            reshaped_sample = reshape_array(sample, data_config.axes)

            # read and reshape target, if available
            if target_files is not None:
                target = read_source_func(target_files[i], data_config.axes)
                reshaped_target = reshape_array(target, data_config.axes)
            else:
                reshaped_target = None
        except Exception as e:
            # Best effort: log unreadable files and continue with the next one.
            logger.error(f"Error reading file {filename}: {e}")
            continue

        # Yield outside the try so exceptions thrown into the generator are
        # not mistaken for read errors.
        yield reshaped_sample, reshaped_target
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Computing data statistics."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute the per-channel mean and standard deviation of an array.

    The input is expected to be laid out as (S, C, (Z), Y, X): axis 1 is the
    channel axis and every other axis is reduced.

    Parameters
    ----------
    image : NDArray
        Input array.

    Returns
    -------
    tuple of (NDArray, NDArray)
        Per-channel mean and (population, ddof=0) standard deviation, each of
        shape (C,).
    """
    # Every axis index except the channel axis (axis 1).
    all_axes = np.arange(image.ndim)
    reduced_axes = tuple(np.delete(all_axes, 1))

    channel_mean = np.mean(image, axis=reduced_axes)
    channel_std = np.std(image, axis=reduced_axes)
    return channel_mean, channel_std
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update per-channel running count, mean and M2 with a new batch of values.

    Batched Welford update: for each channel `c`, all values of
    `new_values[c]` are merged at once into `count[c]`, `mean[c]` and
    `m2[c]`. The standard deviation is later obtained as `sqrt(m2 / count)`.

    `count`, `mean` and `m2` are modified in place (augmented assignment) and
    also returned; `mean` and `m2` therefore need a floating-point dtype.

    Parameters
    ----------
    count : NDArray
        Number of values seen so far, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Running sum of squared deviations from the mean (M2), per channel.
    new_values : NDArray
        New values indexed by channel along the first axis; each
        `new_values[c]` may have any shape and is flattened.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and M2.
    """
    # Flatten each channel once and reuse it for both passes below.
    flat_values = [np.ravel(channel) for channel in new_values]

    count += np.array([values.size for values in flat_values])

    # new values - old mean. The mean is broadcast as a zero-copy *array* view
    # (rather than materializing a per-pixel Python list) so the result dtype
    # follows array-array promotion rules.
    delta = [
        np.subtract(values, np.broadcast_to(m, values.shape))
        for values, m in zip(flat_values, mean)
    ]
    mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])

    # new values - new mean (`mean` was updated in place just above)
    delta2 = [
        np.subtract(values, np.broadcast_to(m, values.shape))
        for values, m in zip(flat_values, mean)
    ]
    m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])

    return (count, mean, m2)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Finalize the mean and standard deviation computation.

    The standard deviation is the population one, `sqrt(m2 / count)`, per
    channel. If any channel has fewer than 2 values, both outputs are
    NaN-filled instead.

    Parameters
    ----------
    count : NDArray
        Number of values seen, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Running sum of squared deviations from the mean (M2), per channel;
        expected to have the same length as `count`.

    Returns
    -------
    tuple[NDArray, NDArray]
        Final mean and standard deviation.
    """
    # Check first so that zero counts never reach the division below (which
    # would emit divide-by-zero / invalid-value RuntimeWarnings).
    if any(c < 2 for c in count):
        return np.full(mean.shape, np.nan), np.full(np.shape(m2), np.nan)

    std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
    return mean, std
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class WelfordStatistics:
    """Accumulate per-channel mean and standard deviation over several arrays.

    Uses a batched variant of Welford's online algorithm, based on:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    Arrays must have their channel axis at position 1 (presumably laid out as
    (S, C, (Z), Y, X), like `compute_normalization_stats` expects).

    The running state (`count`, `mean`, `m2`) is only created by `update` with
    `sample_idx == 0`, so that call must come first. Calling `update` again
    with `sample_idx == 0` restarts the accumulation.
    """

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Fold a new array into the running statistics.

        Parameters
        ----------
        array : NDArray
            Input array, channels along axis 1.
        sample_idx : int
            Index of the current sample; `0` (re)initializes the statistics.
        """
        self.sample_idx = sample_idx
        n_channels = array.shape[1]

        # One sub-array per channel, stacked so that axis 0 indexes channels.
        per_channel = np.array(np.split(array, n_channels, axis=1))

        if self.sample_idx != 0:
            prev_count, prev_mean, prev_m2 = self.count, self.mean, self.m2
        else:
            # Seed the running mean with this array's exact per-channel mean;
            # count and M2 start from zero-valued arrays of shape (C,).
            self.mean, _ = compute_normalization_stats(array)
            prev_count = np.zeros(n_channels)
            prev_mean = self.mean
            prev_m2 = np.zeros(n_channels)

        self.count, self.mean, self.m2 = update_iterative_stats(
            count=prev_count, mean=prev_mean, m2=prev_m2, new_values=per_channel
        )

        self.sample_idx += 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Return the accumulated statistics.

        Returns
        -------
        tuple of numpy arrays
            Per-channel mean and standard deviation; both are NaN-filled if any
            channel has fewer than 2 values.
        """
        return finalize_iterative_stats(self.count, self.mean, self.m2)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# from multiprocessing import Value
|
|
147
|
+
# from typing import tuple
|
|
148
|
+
|
|
149
|
+
# import numpy as np
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# class RunningStats:
|
|
153
|
+
# """Calculates running mean and std."""
|
|
154
|
+
|
|
155
|
+
# def __init__(self) -> None:
|
|
156
|
+
# self.reset()
|
|
157
|
+
|
|
158
|
+
# def reset(self) -> None:
|
|
159
|
+
# """Reset the running stats."""
|
|
160
|
+
# self.avg_mean = Value("d", 0)
|
|
161
|
+
# self.avg_std = Value("d", 0)
|
|
162
|
+
# self.m2 = Value("d", 0)
|
|
163
|
+
# self.count = Value("i", 0)
|
|
164
|
+
|
|
165
|
+
# def init(self, mean: float, std: float) -> None:
|
|
166
|
+
# """Initialize running stats."""
|
|
167
|
+
# with self.avg_mean.get_lock():
|
|
168
|
+
# self.avg_mean.value += mean
|
|
169
|
+
# with self.avg_std.get_lock():
|
|
170
|
+
# self.avg_std.value = std
|
|
171
|
+
|
|
172
|
+
# def compute_std(self) -> tuple[float, float]:
|
|
173
|
+
# """Compute std."""
|
|
174
|
+
# if self.count.value >= 2:
|
|
175
|
+
# self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
|
|
176
|
+
|
|
177
|
+
# def update(self, value: float) -> None:
|
|
178
|
+
# """Update running stats."""
|
|
179
|
+
# with self.count.get_lock():
|
|
180
|
+
# self.count.value += 1
|
|
181
|
+
# delta = value - self.avg_mean.value
|
|
182
|
+
# with self.avg_mean.get_lock():
|
|
183
|
+
# self.avg_mean.value += delta / self.count.value
|
|
184
|
+
# delta2 = value - self.avg_mean.value
|
|
185
|
+
# with self.m2.get_lock():
|
|
186
|
+
# self.m2.value += delta * delta2
|