careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/config/optimizer_models.py
CHANGED

@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import
+from typing import Literal
 
 from pydantic import (
     BaseModel,

@@ -32,7 +32,7 @@ class OptimizerModel(BaseModel):
 
     Attributes
     ----------
-    name :
+    name : {"Adam", "SGD"}
         Name of the optimizer.
     parameters : dict
         Parameters of the optimizer (see torch documentation).

@@ -56,7 +56,7 @@ class OptimizerModel(BaseModel):
 
     @field_validator("parameters")
     @classmethod
-    def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
+    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
         """
         Validate optimizer parameters.
 

@@ -71,7 +71,7 @@ class OptimizerModel(BaseModel):
 
         Returns
         -------
-
+        dict
            Filtered optimizer parameters.
 
         Raises

@@ -127,7 +127,7 @@ class LrSchedulerModel(BaseModel):
 
     Attributes
     ----------
-    name :
+    name : {"ReduceLROnPlateau", "StepLR"}
         Name of the learning rate scheduler.
     parameters : dict
         Parameters of the learning rate scheduler (see torch documentation).

@@ -146,7 +146,7 @@ class LrSchedulerModel(BaseModel):
 
     @field_validator("parameters")
     @classmethod
-    def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
+    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
         """Filter parameters based on the learning rate scheduler's signature.
 
         Parameters

@@ -158,7 +158,7 @@ class LrSchedulerModel(BaseModel):
 
         Returns
         -------
-
+        dict
            Filtered scheduler parameters.
 
         Raises

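A minimal usage sketch of the two models above. The constructor values are illustrative; the `name` sets and the `parameters` attribute come from the documented attributes in these hunks, and the import path is inferred from the file listing.

    from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

    # `name` is restricted to the documented sets {"Adam", "SGD"} and
    # {"ReduceLROnPlateau", "StepLR"}; `parameters` is checked against the
    # corresponding torch signature by the `filter_parameters` validator.
    optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-3})
    scheduler = LrSchedulerModel(name="ReduceLROnPlateau", parameters={"factor": 0.5})
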
careamics/config/tile_information.py
CHANGED

@@ -2,9 +2,7 @@
 
 from __future__ import annotations
 
-from
-
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator
 
 
 class TileInformation(BaseModel):

@@ -13,30 +11,33 @@ class TileInformation(BaseModel):
 
     This model is used to represent the information required to stitch back a tile into
     a larger image. It is used throughout the prediction pipeline of CAREamics.
+
+    Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
+    contain singleton dimensions.
     """
 
     model_config = ConfigDict(validate_default=True)
 
-    array_shape:
-    tiled: bool = False
+    array_shape: tuple[int, ...]
     last_tile: bool = False
-    overlap_crop_coords:
-    stitch_coords:
+    overlap_crop_coords: tuple[tuple[int, ...], ...]
+    stitch_coords: tuple[tuple[int, ...], ...]
+    sample_id: int
 
     @field_validator("array_shape")
     @classmethod
-    def no_singleton_dimensions(cls, v:
+    def no_singleton_dimensions(cls, v: tuple[int, ...]):
         """
         Check that the array shape does not have any singleton dimensions.
 
         Parameters
         ----------
-        v :
+        v : tuple of int
             Array shape to check.
 
         Returns
         -------
-
+        tuple of int
             The array shape if it does not contain singleton dimensions.
 
         Raises

@@ -48,59 +49,26 @@ class TileInformation(BaseModel):
             raise ValueError("Array shape must not contain singleton dimensions.")
         return v
 
-
-
-    def only_if_tiled(cls, v: bool, values: ValidationInfo):
-        """
-        Check that the last tile flag is only set if tiling is enabled.
+    def __eq__(self, other_tile: object):
+        """Check if two tile information objects are equal.
 
         Parameters
         ----------
-
-
-        values : ValidationInfo
-            Validation information.
+        other_tile : object
+            Tile information object to compare with.
 
         Returns
         -------
         bool
-
+            Whether the two tile information objects are equal.
         """
-        if not
-            return
-
-
-
-
-
-
-
-        Check that the coordinates are not `None` if tiling is enabled.
-
-        The method also return `None` if tiling is not enabled.
-
-        Parameters
-        ----------
-        v : Optional[Tuple[int, ...]]
-            Coordinates to check.
-        values : ValidationInfo
-            Validation information.
-
-        Returns
-        -------
-        Optional[Tuple[int, ...]]
-            The coordinates if tiling is enabled, otherwise `None`.
-
-        Raises
-        ------
-        ValueError
-            If the coordinates are `None` and tiling is enabled.
-        """
-        if values.data["tiled"]:
-            if v is None:
-                raise ValueError("Value must be specified if tiling is enabled.")
-
-            return v
-        else:
-            return None
+        if not isinstance(other_tile, TileInformation):
+            return NotImplemented
+
+        return (
+            self.array_shape == other_tile.array_shape
+            and self.last_tile == other_tile.last_tile
+            and self.overlap_crop_coords == other_tile.overlap_crop_coords
+            and self.stitch_coords == other_tile.stitch_coords
+            and self.sample_id == other_tile.sample_id
+        )

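A sketch of the reworked model, using only the fields visible in the hunks above; pydantic v2's `model_dump` is assumed for the copy, and the field values are illustrative.

    from careamics.config.tile_information import TileInformation

    tile = TileInformation(
        array_shape=(8, 8),                    # (C)(Z)YX, no singleton dimensions
        last_tile=False,
        overlap_crop_coords=((0, 8), (0, 8)),
        stitch_coords=((0, 8), (0, 8)),
        sample_id=0,
    )

    # The new __eq__ compares field by field:
    assert tile == TileInformation(**tile.model_dump())

    # A singleton dimension fails the `no_singleton_dimensions` validator:
    # TileInformation(array_shape=(1, 8, 8), ...) raises ValueError
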
careamics/config/transformations/normalize_model.py
CHANGED

@@ -1,8 +1,9 @@
 """Pydantic model for the Normalize transform."""
 
-from typing import Literal
+from typing import Literal, Optional
 
-from pydantic import ConfigDict, Field
+from pydantic import ConfigDict, Field, model_validator
+from typing_extensions import Self
 
 from .transform_model import TransformModel
 

@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
     )
 
     name: Literal["Normalize"] = "Normalize"
-
-
+    image_means: list = Field(..., min_length=0, max_length=32)
+    image_stds: list = Field(..., min_length=0, max_length=32)
+    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
+    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
+
+    @model_validator(mode="after")
+    def validate_means_stds(self: Self) -> Self:
+        """Validate that the means and stds have the same length.
+
+        Returns
+        -------
+        Self
+            The instance of the model.
+        """
+        if len(self.image_means) != len(self.image_stds):
+            raise ValueError("The number of image means and stds must be the same.")
+
+        if (self.target_means is None) != (self.target_stds is None):
+            raise ValueError(
+                "Both target means and stds must be provided together, or both None."
+            )
+
+        if self.target_means is not None and self.target_stds is not None:
+            if len(self.target_means) != len(self.target_stds):
+                raise ValueError(
+                    "The number of target means and stds must be the same."
+                )
+
+        return self

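A short sketch of the new validation logic; the field values are illustrative.

    from careamics.config.transformations.normalize_model import NormalizeModel

    # Valid: one mean/std pair per channel, no targets.
    norm = NormalizeModel(image_means=[0.5, 0.3], image_stds=[0.2, 0.1])

    # Each of the following violates `validate_means_stds` and raises ValueError:
    # NormalizeModel(image_means=[0.5], image_stds=[0.2, 0.1])  # length mismatch
    # NormalizeModel(image_means=[0.5], image_stds=[0.2],
    #                target_means=[0.4])                        # target stds missing
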
careamics/config/validators/validator_utils.py
CHANGED

@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
         If the value is not a power of 2.
     """
     if value < 8:
-        raise ValueError(f"Value must be
+        raise ValueError(f"Value must be greater than 8 (got {value}).")
 
     if (value & (value - 1)) != 0:
         raise ValueError(f"Value must be a power of 2 (got {value}).")

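The `value & (value - 1)` test above works because a power of two has exactly one set bit, so subtracting one flips that bit and sets all lower bits, making the bitwise AND zero. A standalone illustration (not part of the package):

    def is_power_of_two(value: int) -> bool:
        # 0b1000 & 0b0111 == 0, while 0b1100 & 0b1011 == 0b1000 != 0
        return value > 0 and (value & (value - 1)) == 0

    assert is_power_of_two(64)
    assert not is_power_of_two(24)
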
careamics/dataset/__init__.py
CHANGED
@@ -1,6 +1,17 @@
 """Dataset module."""
 
-__all__ = [
+__all__ = [
+    "InMemoryDataset",
+    "InMemoryPredDataset",
+    "InMemoryTiledPredDataset",
+    "PathIterableDataset",
+    "IterableTiledPredDataset",
+    "IterablePredDataset",
+]
 
 from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
 from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredDataset

careamics/dataset/dataset_utils/__init__.py
CHANGED

@@ -2,17 +2,24 @@
 
 __all__ = [
     "reshape_array",
+    "compute_normalization_stats",
     "get_files_size",
     "list_files",
     "validate_source_target_files",
     "read_tiff",
     "get_read_func",
     "read_zarr",
+    "iterate_over_files",
+    "WelfordStatistics",
 ]
 
 
-from .dataset_utils import
+from .dataset_utils import (
+    reshape_array,
+)
 from .file_utils import get_files_size, list_files, validate_source_target_files
+from .iterate_over_files import iterate_over_files
 from .read_tiff import read_tiff
 from .read_utils import get_read_func
 from .read_zarr import read_zarr
+from .running_stats import WelfordStatistics, compute_normalization_stats

careamics/dataset/dataset_utils/file_utils.py
CHANGED

@@ -33,7 +33,7 @@ def list_files(
     data_type: Union[str, SupportedData],
     extension_filter: str = "",
 ) -> List[Path]:
-    """
+    """List recursively files in `data_path` and return a sorted list.
 
     If `data_path` is a file, its name is validated against the `data_type` using
     `fnmatch`, and the method returns `data_path` itself.

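A hedged usage sketch: the first parameter name `data_path` comes from the docstring above, and the `"tiff"` value for `data_type` is an assumption about the `SupportedData` string values.

    from pathlib import Path

    from careamics.dataset.dataset_utils import list_files

    # Recursively collects matching files under data/ into a sorted list of Paths.
    files = list_files(Path("data"), data_type="tiff")
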
careamics/dataset/dataset_utils/iterate_over_files.py
ADDED

@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, Optional, Union
+
+from numpy.typing import NDArray
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+from .read_tiff import read_tiff
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: list[Path],
+    target_files: Optional[list[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : CAREamics DataConfig or InferenceConfig
+        Configuration.
+    data_files : list of pathlib.Path
+        List of data files.
+    target_files : list of pathlib.Path, optional
+        List of target files, by default None.
+    read_source_func : Callable, optional
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    NDArray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")

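A sketch of driving the generator directly. Only `axes` is read by `iterate_over_files` itself; the other `DataConfig` fields shown here are assumptions about that model's required fields.

    from pathlib import Path

    from careamics.config import DataConfig
    from careamics.dataset.dataset_utils import iterate_over_files

    config = DataConfig(
        data_type="tiff", axes="YX", patch_size=[64, 64], batch_size=1
    )
    files = sorted(Path("images").glob("*.tif"))

    # Outside a DataLoader worker, get_worker_info() is None, so a single
    # "worker" yields every file; inside workers, files are sharded by
    # i % num_workers == worker_id.
    for sample, target in iterate_over_files(config, files):
        print(sample.shape)  # reshaped according to config.axes; target is None
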
careamics/dataset/dataset_utils/read_tiff.py
CHANGED

@@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     else:
         raise ValueError(f"File {file_path} is not a valid tiff.")
 
-    # check dimensions
-    # TODO or should this really be done here? probably in the LightningDataModule
-    # TODO this should also be centralized somewhere else (validate_dimensions)
-    if len(array.shape) < 2 or len(array.shape) > 6:
-        raise ValueError(
-            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
-            f"file {file_path})."
-        )
-
     return array

careamics/dataset/dataset_utils/running_stats.py
ADDED

@@ -0,0 +1,186 @@
+"""Computing data statistics."""
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
+    """
+    Compute mean and standard deviation of an array.
+
+    Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
+    computed per channel.
+
+    Parameters
+    ----------
+    image : NDArray
+        Input array.
+
+    Returns
+    -------
+    tuple of (list of floats, list of floats)
+        Lists of mean and standard deviation values per channel.
+    """
+    # Define the list of axes excluding the channel axis
+    axes = tuple(np.delete(np.arange(image.ndim), 1))
+    return np.mean(image, axis=axes), np.std(image, axis=axes)
+
+
+def update_iterative_stats(
+    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
+) -> tuple[NDArray, NDArray, NDArray]:
+    """Update the mean and variance of an array iteratively.
+
+    Parameters
+    ----------
+    count : NDArray
+        Number of elements in the array.
+    mean : NDArray
+        Mean of the array.
+    m2 : NDArray
+        Variance of the array.
+    new_values : NDArray
+        New values to add to the mean and variance.
+
+    Returns
+    -------
+    tuple[NDArray, NDArray, NDArray]
+        Updated count, mean, and variance.
+    """
+    count += np.array([np.prod(channel.shape) for channel in new_values])
+    # newvalues - oldMean
+    delta = [
+        np.subtract(v.flatten(), [m] * len(v.flatten()))
+        for v, m in zip(new_values, mean)
+    ]
+
+    mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
+    # newvalues - newMean
+    delta2 = [
+        np.subtract(v.flatten(), [m] * len(v.flatten()))
+        for v, m in zip(new_values, mean)
+    ]
+
+    m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
+
+    return (count, mean, m2)
+
+
+def finalize_iterative_stats(
+    count: NDArray, mean: NDArray, m2: NDArray
+) -> tuple[NDArray, NDArray]:
+    """Finalize the mean and variance computation.
+
+    Parameters
+    ----------
+    count : NDArray
+        Number of elements in the array.
+    mean : NDArray
+        Mean of the array.
+    m2 : NDArray
+        Variance of the array.
+
+    Returns
+    -------
+    tuple[NDArray, NDArray]
+        Final mean and standard deviation.
+    """
+    std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
+    if any(c < 2 for c in count):
+        return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
+    else:
+        return mean, std
+
+
+class WelfordStatistics:
+    """Compute Welford statistics iteratively.
+
+    The Welford algorithm is used to compute the mean and variance of an array
+    iteratively. Based on the implementation from:
+    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    """
+
+    def update(self, array: NDArray, sample_idx: int) -> None:
+        """Update the Welford statistics.
+
+        Parameters
+        ----------
+        array : NDArray
+            Input array.
+        sample_idx : int
+            Current sample number.
+        """
+        self.sample_idx = sample_idx
+        sample_channels = np.array(np.split(array, array.shape[1], axis=1))
+
+        # Initialize the statistics
+        if self.sample_idx == 0:
+            # Compute the mean and standard deviation
+            self.mean, _ = compute_normalization_stats(array)
+            # Initialize the count and m2 with zero-valued arrays of shape (C,)
+            self.count, self.mean, self.m2 = update_iterative_stats(
+                count=np.zeros(array.shape[1]),
+                mean=self.mean,
+                m2=np.zeros(array.shape[1]),
+                new_values=sample_channels,
+            )
+        else:
+            # Update the statistics
+            self.count, self.mean, self.m2 = update_iterative_stats(
+                count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
+            )
+
+        self.sample_idx += 1
+
+    def finalize(self) -> tuple[NDArray, NDArray]:
+        """Finalize the Welford statistics.
+
+        Returns
+        -------
+        tuple of numpy arrays
+            Final mean and standard deviation.
+        """
+        return finalize_iterative_stats(self.count, self.mean, self.m2)
+
+
+# from multiprocessing import Value
+# from typing import tuple
+
+# import numpy as np
+
+
+# class RunningStats:
+#     """Calculates running mean and std."""
+
+#     def __init__(self) -> None:
+#         self.reset()
+
+#     def reset(self) -> None:
+#         """Reset the running stats."""
+#         self.avg_mean = Value("d", 0)
+#         self.avg_std = Value("d", 0)
+#         self.m2 = Value("d", 0)
+#         self.count = Value("i", 0)
+
+#     def init(self, mean: float, std: float) -> None:
+#         """Initialize running stats."""
+#         with self.avg_mean.get_lock():
+#             self.avg_mean.value += mean
+#         with self.avg_std.get_lock():
+#             self.avg_std.value = std
+
+#     def compute_std(self) -> tuple[float, float]:
+#         """Compute std."""
+#         if self.count.value >= 2:
+#             self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
+
+#     def update(self, value: float) -> None:
+#         """Update running stats."""
+#         with self.count.get_lock():
+#             self.count.value += 1
+#         delta = value - self.avg_mean.value
+#         with self.avg_mean.get_lock():
+#             self.avg_mean.value += delta / self.count.value
+#         delta2 = value - self.avg_mean.value
+#         with self.m2.get_lock():
+#             self.m2.value += delta * delta2

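A usage sketch of the two public entry points above, on synthetic (S, C, Y, X) data; the import path follows the dataset_utils re-exports shown earlier in this diff.

    import numpy as np

    from careamics.dataset.dataset_utils import (
        WelfordStatistics,
        compute_normalization_stats,
    )

    rng = np.random.default_rng(0)

    # One-shot, per-channel statistics on a single array:
    image = rng.normal(size=(4, 2, 32, 32))
    means, stds = compute_normalization_stats(image)  # two arrays of shape (2,)

    # Streaming statistics over successive samples; update() initializes the
    # running count/mean/m2 when sample_idx == 0:
    stats = WelfordStatistics()
    for idx in range(10):
        stats.update(rng.normal(size=(1, 2, 32, 32)), sample_idx=idx)
    w_mean, w_std = stats.finalize()
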