careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69) hide show
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Dict, Literal
5
+ from typing import Literal
6
6
 
7
7
  from pydantic import (
8
8
  BaseModel,
@@ -32,7 +32,7 @@ class OptimizerModel(BaseModel):
32
32
 
33
33
  Attributes
34
34
  ----------
35
- name : TorchOptimizer
35
+ name : {"Adam", "SGD"}
36
36
  Name of the optimizer.
37
37
  parameters : dict
38
38
  Parameters of the optimizer (see torch documentation).
@@ -56,7 +56,7 @@ class OptimizerModel(BaseModel):
56
56
 
57
57
  @field_validator("parameters")
58
58
  @classmethod
59
- def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
59
+ def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
60
60
  """
61
61
  Validate optimizer parameters.
62
62
 
@@ -71,7 +71,7 @@ class OptimizerModel(BaseModel):
71
71
 
72
72
  Returns
73
73
  -------
74
- Dict
74
+ dict
75
75
  Filtered optimizer parameters.
76
76
 
77
77
  Raises
@@ -127,7 +127,7 @@ class LrSchedulerModel(BaseModel):
127
127
 
128
128
  Attributes
129
129
  ----------
130
- name : TorchLRScheduler
130
+ name : {"ReduceLROnPlateau", "StepLR"}
131
131
  Name of the learning rate scheduler.
132
132
  parameters : dict
133
133
  Parameters of the learning rate scheduler (see torch documentation).
@@ -146,7 +146,7 @@ class LrSchedulerModel(BaseModel):
146
146
 
147
147
  @field_validator("parameters")
148
148
  @classmethod
149
- def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
149
+ def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
150
150
  """Filter parameters based on the learning rate scheduler's signature.
151
151
 
152
152
  Parameters
@@ -158,7 +158,7 @@ class LrSchedulerModel(BaseModel):
158
158
 
159
159
  Returns
160
160
  -------
161
- Dict
161
+ dict
162
162
  Filtered scheduler parameters.
163
163
 
164
164
  Raises
@@ -8,5 +8,4 @@ class SupportedTransform(str, BaseEnum):
8
8
 
9
9
  XY_FLIP = "XYFlip"
10
10
  XY_RANDOM_ROTATE90 = "XYRandomRotate90"
11
- NORMALIZE = "Normalize"
12
11
  N2V_MANIPULATE = "N2VManipulate"
@@ -2,9 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Optional, Tuple
6
-
7
- from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
5
+ from pydantic import BaseModel, ConfigDict, field_validator
8
6
 
9
7
 
10
8
  class TileInformation(BaseModel):
@@ -13,30 +11,33 @@ class TileInformation(BaseModel):
13
11
 
14
12
  This model is used to represent the information required to stitch back a tile into
15
13
  a larger image. It is used throughout the prediction pipeline of CAREamics.
14
+
15
+ Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
16
+ contain singleton dimensions.
16
17
  """
17
18
 
18
19
  model_config = ConfigDict(validate_default=True)
19
20
 
20
- array_shape: Tuple[int, ...]
21
- tiled: bool = False
21
+ array_shape: tuple[int, ...]
22
22
  last_tile: bool = False
23
- overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
24
- stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
23
+ overlap_crop_coords: tuple[tuple[int, ...], ...]
24
+ stitch_coords: tuple[tuple[int, ...], ...]
25
+ sample_id: int
25
26
 
26
27
  @field_validator("array_shape")
27
28
  @classmethod
28
- def no_singleton_dimensions(cls, v: Tuple[int, ...]):
29
+ def no_singleton_dimensions(cls, v: tuple[int, ...]):
29
30
  """
30
31
  Check that the array shape does not have any singleton dimensions.
31
32
 
32
33
  Parameters
33
34
  ----------
34
- v : Tuple[int, ...]
35
+ v : tuple of int
35
36
  Array shape to check.
36
37
 
37
38
  Returns
38
39
  -------
39
- Tuple[int, ...]
40
+ tuple of int
40
41
  The array shape if it does not contain singleton dimensions.
41
42
 
42
43
  Raises
@@ -48,59 +49,26 @@ class TileInformation(BaseModel):
48
49
  raise ValueError("Array shape must not contain singleton dimensions.")
49
50
  return v
50
51
 
51
- @field_validator("last_tile")
52
- @classmethod
53
- def only_if_tiled(cls, v: bool, values: ValidationInfo):
54
- """
55
- Check that the last tile flag is only set if tiling is enabled.
52
+ def __eq__(self, other_tile: object):
53
+ """Check if two tile information objects are equal.
56
54
 
57
55
  Parameters
58
56
  ----------
59
- v : bool
60
- Last tile flag.
61
- values : ValidationInfo
62
- Validation information.
57
+ other_tile : object
58
+ Tile information object to compare with.
63
59
 
64
60
  Returns
65
61
  -------
66
62
  bool
67
- The last tile flag.
63
+ Whether the two tile information objects are equal.
68
64
  """
69
- if not values.data["tiled"]:
70
- return False
71
- return v
72
-
73
- @field_validator("overlap_crop_coords", "stitch_coords")
74
- @classmethod
75
- def mandatory_if_tiled(
76
- cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
77
- ) -> Optional[Tuple[int, ...]]:
78
- """
79
- Check that the coordinates are not `None` if tiling is enabled.
80
-
81
- The method also return `None` if tiling is not enabled.
82
-
83
- Parameters
84
- ----------
85
- v : Optional[Tuple[int, ...]]
86
- Coordinates to check.
87
- values : ValidationInfo
88
- Validation information.
89
-
90
- Returns
91
- -------
92
- Optional[Tuple[int, ...]]
93
- The coordinates if tiling is enabled, otherwise `None`.
94
-
95
- Raises
96
- ------
97
- ValueError
98
- If the coordinates are `None` and tiling is enabled.
99
- """
100
- if values.data["tiled"]:
101
- if v is None:
102
- raise ValueError("Value must be specified if tiling is enabled.")
103
-
104
- return v
105
- else:
106
- return None
65
+ if not isinstance(other_tile, TileInformation):
66
+ return NotImplemented
67
+
68
+ return (
69
+ self.array_shape == other_tile.array_shape
70
+ and self.last_tile == other_tile.last_tile
71
+ and self.overlap_crop_coords == other_tile.overlap_crop_coords
72
+ and self.stitch_coords == other_tile.stitch_coords
73
+ and self.sample_id == other_tile.sample_id
74
+ )
@@ -1,8 +1,9 @@
1
1
  """Pydantic model for the Normalize transform."""
2
2
 
3
- from typing import Literal
3
+ from typing import Literal, Optional
4
4
 
5
- from pydantic import ConfigDict, Field
5
+ from pydantic import ConfigDict, Field, model_validator
6
+ from typing_extensions import Self
6
7
 
7
8
  from .transform_model import TransformModel
8
9
 
@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
28
29
  )
29
30
 
30
31
  name: Literal["Normalize"] = "Normalize"
31
- mean: float = Field(default=0.485) # albumentations defaults
32
- std: float = Field(default=0.229)
32
+ image_means: list = Field(..., min_length=0, max_length=32)
33
+ image_stds: list = Field(..., min_length=0, max_length=32)
34
+ target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
35
+ target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
36
+
37
+ @model_validator(mode="after")
38
+ def validate_means_stds(self: Self) -> Self:
39
+ """Validate that the means and stds have the same length.
40
+
41
+ Returns
42
+ -------
43
+ Self
44
+ The instance of the model.
45
+ """
46
+ if len(self.image_means) != len(self.image_stds):
47
+ raise ValueError("The number of image means and stds must be the same.")
48
+
49
+ if (self.target_means is None) != (self.target_stds is None):
50
+ raise ValueError(
51
+ "Both target means and stds must be provided together, or bot None."
52
+ )
53
+
54
+ if self.target_means is not None and self.target_stds is not None:
55
+ if len(self.target_means) != len(self.target_stds):
56
+ raise ValueError(
57
+ "The number of target means and stds must be the same."
58
+ )
59
+
60
+ return self
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
72
72
  If the value is not a power of 2.
73
73
  """
74
74
  if value < 8:
75
- raise ValueError(f"Value must be non-zero positive (got {value}).")
75
+ raise ValueError(f"Value must be greater than 8 (got {value}).")
76
76
 
77
77
  if (value & (value - 1)) != 0:
78
78
  raise ValueError(f"Value must be a power of 2 (got {value}).")
@@ -1,6 +1,17 @@
1
1
  """Dataset module."""
2
2
 
3
- __all__ = ["InMemoryDataset", "PathIterableDataset"]
3
+ __all__ = [
4
+ "InMemoryDataset",
5
+ "InMemoryPredDataset",
6
+ "InMemoryTiledPredDataset",
7
+ "PathIterableDataset",
8
+ "IterableTiledPredDataset",
9
+ "IterablePredDataset",
10
+ ]
4
11
 
5
12
  from .in_memory_dataset import InMemoryDataset
13
+ from .in_memory_pred_dataset import InMemoryPredDataset
14
+ from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
6
15
  from .iterable_dataset import PathIterableDataset
16
+ from .iterable_pred_dataset import IterablePredDataset
17
+ from .iterable_tiled_pred_dataset import IterableTiledPredDataset
@@ -2,17 +2,24 @@
2
2
 
3
3
  __all__ = [
4
4
  "reshape_array",
5
+ "compute_normalization_stats",
5
6
  "get_files_size",
6
7
  "list_files",
7
8
  "validate_source_target_files",
8
9
  "read_tiff",
9
10
  "get_read_func",
10
11
  "read_zarr",
12
+ "iterate_over_files",
13
+ "WelfordStatistics",
11
14
  ]
12
15
 
13
16
 
14
- from .dataset_utils import reshape_array
17
+ from .dataset_utils import (
18
+ reshape_array,
19
+ )
15
20
  from .file_utils import get_files_size, list_files, validate_source_target_files
21
+ from .iterate_over_files import iterate_over_files
16
22
  from .read_tiff import read_tiff
17
23
  from .read_utils import get_read_func
18
24
  from .read_zarr import read_zarr
25
+ from .running_stats import WelfordStatistics, compute_normalization_stats
@@ -33,7 +33,7 @@ def list_files(
33
33
  data_type: Union[str, SupportedData],
34
34
  extension_filter: str = "",
35
35
  ) -> List[Path]:
36
- """Create a recursive list of files in `data_path`.
36
+ """List recursively files in `data_path` and return a sorted list.
37
37
 
38
38
  If `data_path` is a file, its name is validated against the `data_type` using
39
39
  `fnmatch`, and the method returns `data_path` itself.
@@ -0,0 +1,83 @@
1
+ """Function to iterate over files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Callable, Generator, Optional, Union
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import get_worker_info
10
+
11
+ from careamics.config import DataConfig, InferenceConfig
12
+ from careamics.utils.logging import get_logger
13
+
14
+ from .dataset_utils import reshape_array
15
+ from .read_tiff import read_tiff
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ def iterate_over_files(
21
+ data_config: Union[DataConfig, InferenceConfig],
22
+ data_files: list[Path],
23
+ target_files: Optional[list[Path]] = None,
24
+ read_source_func: Callable = read_tiff,
25
+ ) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
26
+ """Iterate over data source and yield whole reshaped images.
27
+
28
+ Parameters
29
+ ----------
30
+ data_config : CAREamics DataConfig or InferenceConfig
31
+ Configuration.
32
+ data_files : list of pathlib.Path
33
+ List of data files.
34
+ target_files : list of pathlib.Path, optional
35
+ List of target files, by default None.
36
+ read_source_func : Callable, optional
37
+ Function to read the source, by default read_tiff.
38
+
39
+ Yields
40
+ ------
41
+ NDArray
42
+ Image.
43
+ """
44
+ # When num_workers > 0, each worker process will have a different copy of the
45
+ # dataset object
46
+ # Configuring each copy independently to avoid having duplicate data returned
47
+ # from the workers
48
+ worker_info = get_worker_info()
49
+ worker_id = worker_info.id if worker_info is not None else 0
50
+ num_workers = worker_info.num_workers if worker_info is not None else 1
51
+
52
+ # iterate over the files
53
+ for i, filename in enumerate(data_files):
54
+ # retrieve file corresponding to the worker id
55
+ if i % num_workers == worker_id:
56
+ try:
57
+ # read data
58
+ sample = read_source_func(filename, data_config.axes)
59
+
60
+ # reshape array
61
+ reshaped_sample = reshape_array(sample, data_config.axes)
62
+
63
+ # read target, if available
64
+ if target_files is not None:
65
+ if filename.name != target_files[i].name:
66
+ raise ValueError(
67
+ f"File {filename} does not match target file "
68
+ f"{target_files[i]}. Have you passed sorted "
69
+ f"arrays?"
70
+ )
71
+
72
+ # read target
73
+ target = read_source_func(target_files[i], data_config.axes)
74
+
75
+ # reshape target
76
+ reshaped_target = reshape_array(target, data_config.axes)
77
+
78
+ yield reshaped_sample, reshaped_target
79
+ else:
80
+ yield reshaped_sample, None
81
+
82
+ except Exception as e:
83
+ logger.error(f"Error reading file {filename}: {e}")
@@ -53,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
53
53
  else:
54
54
  raise ValueError(f"File {file_path} is not a valid tiff.")
55
55
 
56
- # check dimensions
57
- # TODO or should this really be done here? probably in the LightningDataModule
58
- # TODO this should also be centralized somewhere else (validate_dimensions)
59
- if len(array.shape) < 2 or len(array.shape) > 6:
60
- raise ValueError(
61
- f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
62
- f"file {file_path})."
63
- )
64
-
65
56
  return array
@@ -0,0 +1,186 @@
1
+ """Computing data statistics."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
+ def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
8
+ """
9
+ Compute mean and standard deviation of an array.
10
+
11
+ Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
12
+ computed per channel.
13
+
14
+ Parameters
15
+ ----------
16
+ image : NDArray
17
+ Input array.
18
+
19
+ Returns
20
+ -------
21
+ tuple of (list of floats, list of floats)
22
+ Lists of mean and standard deviation values per channel.
23
+ """
24
+ # Define the list of axes excluding the channel axis
25
+ axes = tuple(np.delete(np.arange(image.ndim), 1))
26
+ return np.mean(image, axis=axes), np.std(image, axis=axes)
27
+
28
+
29
+ def update_iterative_stats(
30
+ count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
31
+ ) -> tuple[NDArray, NDArray, NDArray]:
32
+ """Update the mean and variance of an array iteratively.
33
+
34
+ Parameters
35
+ ----------
36
+ count : NDArray
37
+ Number of elements in the array.
38
+ mean : NDArray
39
+ Mean of the array.
40
+ m2 : NDArray
41
+ Variance of the array.
42
+ new_values : NDArray
43
+ New values to add to the mean and variance.
44
+
45
+ Returns
46
+ -------
47
+ tuple[NDArray, NDArray, NDArray]
48
+ Updated count, mean, and variance.
49
+ """
50
+ count += np.array([np.prod(channel.shape) for channel in new_values])
51
+ # newvalues - oldMean
52
+ delta = [
53
+ np.subtract(v.flatten(), [m] * len(v.flatten()))
54
+ for v, m in zip(new_values, mean)
55
+ ]
56
+
57
+ mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
58
+ # newvalues - newMeant
59
+ delta2 = [
60
+ np.subtract(v.flatten(), [m] * len(v.flatten()))
61
+ for v, m in zip(new_values, mean)
62
+ ]
63
+
64
+ m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
65
+
66
+ return (count, mean, m2)
67
+
68
+
69
+ def finalize_iterative_stats(
70
+ count: NDArray, mean: NDArray, m2: NDArray
71
+ ) -> tuple[NDArray, NDArray]:
72
+ """Finalize the mean and variance computation.
73
+
74
+ Parameters
75
+ ----------
76
+ count : NDArray
77
+ Number of elements in the array.
78
+ mean : NDArray
79
+ Mean of the array.
80
+ m2 : NDArray
81
+ Variance of the array.
82
+
83
+ Returns
84
+ -------
85
+ tuple[NDArray, NDArray]
86
+ Final mean and standard deviation.
87
+ """
88
+ std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
89
+ if any(c < 2 for c in count):
90
+ return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
91
+ else:
92
+ return mean, std
93
+
94
+
95
+ class WelfordStatistics:
96
+ """Compute Welford statistics iteratively.
97
+
98
+ The Welford algorithm is used to compute the mean and variance of an array
99
+ iteratively. Based on the implementation from:
100
+ https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
101
+ """
102
+
103
+ def update(self, array: NDArray, sample_idx: int) -> None:
104
+ """Update the Welford statistics.
105
+
106
+ Parameters
107
+ ----------
108
+ array : NDArray
109
+ Input array.
110
+ sample_idx : int
111
+ Current sample number.
112
+ """
113
+ self.sample_idx = sample_idx
114
+ sample_channels = np.array(np.split(array, array.shape[1], axis=1))
115
+
116
+ # Initialize the statistics
117
+ if self.sample_idx == 0:
118
+ # Compute the mean and standard deviation
119
+ self.mean, _ = compute_normalization_stats(array)
120
+ # Initialize the count and m2 with zero-valued arrays of shape (C,)
121
+ self.count, self.mean, self.m2 = update_iterative_stats(
122
+ count=np.zeros(array.shape[1]),
123
+ mean=self.mean,
124
+ m2=np.zeros(array.shape[1]),
125
+ new_values=sample_channels,
126
+ )
127
+ else:
128
+ # Update the statistics
129
+ self.count, self.mean, self.m2 = update_iterative_stats(
130
+ count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
131
+ )
132
+
133
+ self.sample_idx += 1
134
+
135
+ def finalize(self) -> tuple[NDArray, NDArray]:
136
+ """Finalize the Welford statistics.
137
+
138
+ Returns
139
+ -------
140
+ tuple or numpy arrays
141
+ Final mean and standard deviation.
142
+ """
143
+ return finalize_iterative_stats(self.count, self.mean, self.m2)
144
+
145
+
146
+ # from multiprocessing import Value
147
+ # from typing import tuple
148
+
149
+ # import numpy as np
150
+
151
+
152
+ # class RunningStats:
153
+ # """Calculates running mean and std."""
154
+
155
+ # def __init__(self) -> None:
156
+ # self.reset()
157
+
158
+ # def reset(self) -> None:
159
+ # """Reset the running stats."""
160
+ # self.avg_mean = Value("d", 0)
161
+ # self.avg_std = Value("d", 0)
162
+ # self.m2 = Value("d", 0)
163
+ # self.count = Value("i", 0)
164
+
165
+ # def init(self, mean: float, std: float) -> None:
166
+ # """Initialize running stats."""
167
+ # with self.avg_mean.get_lock():
168
+ # self.avg_mean.value += mean
169
+ # with self.avg_std.get_lock():
170
+ # self.avg_std.value = std
171
+
172
+ # def compute_std(self) -> tuple[float, float]:
173
+ # """Compute std."""
174
+ # if self.count.value >= 2:
175
+ # self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
176
+
177
+ # def update(self, value: float) -> None:
178
+ # """Update running stats."""
179
+ # with self.count.get_lock():
180
+ # self.count.value += 1
181
+ # delta = value - self.avg_mean.value
182
+ # with self.avg_mean.get_lock():
183
+ # self.avg_mean.value += delta / self.count.value
184
+ # delta2 = value - self.avg_mean.value
185
+ # with self.m2.get_lock():
186
+ # self.m2.value += delta * delta2