careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -2,9 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Optional, Tuple
6
-
7
- from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
5
+ from pydantic import BaseModel, ConfigDict, field_validator
8
6
 
9
7
 
10
8
  class TileInformation(BaseModel):
@@ -13,30 +11,43 @@ class TileInformation(BaseModel):
13
11
 
14
12
  This model is used to represent the information required to stitch back a tile into
15
13
  a larger image. It is used throughout the prediction pipeline of CAREamics.
14
+
15
+ Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
16
+ contain singleton dimensions.
16
17
  """
17
18
 
18
19
  model_config = ConfigDict(validate_default=True)
19
20
 
20
- array_shape: Tuple[int, ...]
21
- tiled: bool = False
21
+ array_shape: tuple[int, ...]
22
+ """Shape of the original (untiled) array."""
23
+
22
24
  last_tile: bool = False
23
- overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
24
- stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
25
+ """Whether this tile is the last one of the array."""
26
+
27
+ overlap_crop_coords: tuple[tuple[int, ...], ...]
28
+ """Inner coordinates of the tile where to crop the prediction in order to stitch
29
+ it back into the original image."""
30
+
31
+ stitch_coords: tuple[tuple[int, ...], ...]
32
+ """Coordinates in the original image where to stitch the cropped tile back."""
33
+
34
+ sample_id: int
35
+ """Sample ID of the tile."""
25
36
 
26
37
  @field_validator("array_shape")
27
38
  @classmethod
28
- def no_singleton_dimensions(cls, v: Tuple[int, ...]):
39
+ def no_singleton_dimensions(cls, v: tuple[int, ...]):
29
40
  """
30
41
  Check that the array shape does not have any singleton dimensions.
31
42
 
32
43
  Parameters
33
44
  ----------
34
- v : Tuple[int, ...]
45
+ v : tuple of int
35
46
  Array shape to check.
36
47
 
37
48
  Returns
38
49
  -------
39
- Tuple[int, ...]
50
+ tuple of int
40
51
  The array shape if it does not contain singleton dimensions.
41
52
 
42
53
  Raises
@@ -48,59 +59,26 @@ class TileInformation(BaseModel):
48
59
  raise ValueError("Array shape must not contain singleton dimensions.")
49
60
  return v
50
61
 
51
- @field_validator("last_tile")
52
- @classmethod
53
- def only_if_tiled(cls, v: bool, values: ValidationInfo):
54
- """
55
- Check that the last tile flag is only set if tiling is enabled.
62
+ def __eq__(self, other_tile: object):
63
+ """Check if two tile information objects are equal.
56
64
 
57
65
  Parameters
58
66
  ----------
59
- v : bool
60
- Last tile flag.
61
- values : ValidationInfo
62
- Validation information.
67
+ other_tile : object
68
+ Tile information object to compare with.
63
69
 
64
70
  Returns
65
71
  -------
66
72
  bool
67
- The last tile flag.
68
- """
69
- if not values.data["tiled"]:
70
- return False
71
- return v
72
-
73
- @field_validator("overlap_crop_coords", "stitch_coords")
74
- @classmethod
75
- def mandatory_if_tiled(
76
- cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
77
- ) -> Optional[Tuple[int, ...]]:
73
+ Whether the two tile information objects are equal.
78
74
  """
79
- Check that the coordinates are not `None` if tiling is enabled.
80
-
81
- The method also return `None` if tiling is not enabled.
82
-
83
- Parameters
84
- ----------
85
- v : Optional[Tuple[int, ...]]
86
- Coordinates to check.
87
- values : ValidationInfo
88
- Validation information.
89
-
90
- Returns
91
- -------
92
- Optional[Tuple[int, ...]]
93
- The coordinates if tiling is enabled, otherwise `None`.
94
-
95
- Raises
96
- ------
97
- ValueError
98
- If the coordinates are `None` and tiling is enabled.
99
- """
100
- if values.data["tiled"]:
101
- if v is None:
102
- raise ValueError("Value must be specified if tiling is enabled.")
103
-
104
- return v
105
- else:
106
- return None
75
+ if not isinstance(other_tile, TileInformation):
76
+ return NotImplemented
77
+
78
+ return (
79
+ self.array_shape == other_tile.array_shape
80
+ and self.last_tile == other_tile.last_tile
81
+ and self.overlap_crop_coords == other_tile.overlap_crop_coords
82
+ and self.stitch_coords == other_tile.stitch_coords
83
+ and self.sample_id == other_tile.sample_id
84
+ )
@@ -35,15 +35,19 @@ class TrainingConfig(BaseModel):
35
35
  )
36
36
 
37
37
  num_epochs: int = Field(default=20, ge=1)
38
+ """Number of epochs, greater than 0."""
38
39
 
39
40
  logger: Optional[Literal["wandb", "tensorboard"]] = None
41
+ """Logger to use during training. If None, no logger will be used. Available
42
+ loggers are defined in SupportedLogger."""
40
43
 
41
44
  checkpoint_callback: CheckpointModel = CheckpointModel()
45
+ """Checkpoint callback configuration."""
42
46
 
43
47
  early_stopping_callback: Optional[EarlyStoppingModel] = Field(
44
48
  default=None, validate_default=True
45
49
  )
46
- # precision: Literal["64", "32", "16", "bf16"] = 32
50
+ """Early stopping callback configuration."""
47
51
 
48
52
  def __str__(self) -> str:
49
53
  """Pretty string reprensenting the configuration.
@@ -1,8 +1,9 @@
1
1
  """Pydantic model for the Normalize transform."""
2
2
 
3
- from typing import Literal
3
+ from typing import Literal, Optional
4
4
 
5
- from pydantic import ConfigDict, Field
5
+ from pydantic import ConfigDict, Field, model_validator
6
+ from typing_extensions import Self
6
7
 
7
8
  from .transform_model import TransformModel
8
9
 
@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
28
29
  )
29
30
 
30
31
  name: Literal["Normalize"] = "Normalize"
31
- mean: float = Field(default=0.485) # albumentations defaults
32
- std: float = Field(default=0.229)
32
+ image_means: list = Field(..., min_length=0, max_length=32)
33
+ image_stds: list = Field(..., min_length=0, max_length=32)
34
+ target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
35
+ target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
36
+
37
@model_validator(mode="after")
def validate_means_stds(self: Self) -> Self:
    """Validate that the means and stds are consistent with each other.

    Checks that `image_means` and `image_stds` have the same length, that
    `target_means` and `target_stds` are either both provided or both None, and,
    when provided, that they have the same length.

    Returns
    -------
    Self
        The instance of the model.

    Raises
    ------
    ValueError
        If any of the conditions above is not met.
    """
    if len(self.image_means) != len(self.image_stds):
        raise ValueError("The number of image means and stds must be the same.")

    # Target statistics are optional, but must be given as a pair.
    if (self.target_means is None) != (self.target_stds is None):
        raise ValueError(
            "Both target means and stds must be provided together, or both None."
        )

    if self.target_means is not None and self.target_stds is not None:
        if len(self.target_means) != len(self.target_stds):
            raise ValueError(
                "The number of target means and stds must be the same."
            )

    return self
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
72
72
  If the value is not a power of 2.
73
73
  """
74
74
  if value < 8:
75
- raise ValueError(f"Value must be non-zero positive (got {value}).")
75
+ raise ValueError(f"Value must be greater than 8 (got {value}).")
76
76
 
77
77
  if (value & (value - 1)) != 0:
78
78
  raise ValueError(f"Value must be a power of 2 (got {value}).")
@@ -1,6 +1,17 @@
1
1
  """Dataset module."""
2
2
 
3
- __all__ = ["InMemoryDataset", "PathIterableDataset"]
3
+ __all__ = [
4
+ "InMemoryDataset",
5
+ "InMemoryPredDataset",
6
+ "InMemoryTiledPredDataset",
7
+ "PathIterableDataset",
8
+ "IterableTiledPredDataset",
9
+ "IterablePredDataset",
10
+ ]
4
11
 
5
12
  from .in_memory_dataset import InMemoryDataset
13
+ from .in_memory_pred_dataset import InMemoryPredDataset
14
+ from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
6
15
  from .iterable_dataset import PathIterableDataset
16
+ from .iterable_pred_dataset import IterablePredDataset
17
+ from .iterable_tiled_pred_dataset import IterableTiledPredDataset
@@ -2,17 +2,18 @@
2
2
 
3
3
  __all__ = [
4
4
  "reshape_array",
5
+ "compute_normalization_stats",
5
6
  "get_files_size",
6
7
  "list_files",
7
8
  "validate_source_target_files",
8
- "read_tiff",
9
- "get_read_func",
10
- "read_zarr",
9
+ "iterate_over_files",
10
+ "WelfordStatistics",
11
11
  ]
12
12
 
13
13
 
14
- from .dataset_utils import reshape_array
14
+ from .dataset_utils import (
15
+ reshape_array,
16
+ )
15
17
  from .file_utils import get_files_size, list_files, validate_source_target_files
16
- from .read_tiff import read_tiff
17
- from .read_utils import get_read_func
18
- from .read_zarr import read_zarr
18
+ from .iterate_over_files import iterate_over_files
19
+ from .running_stats import WelfordStatistics, compute_normalization_stats
@@ -33,7 +33,7 @@ def list_files(
33
33
  data_type: Union[str, SupportedData],
34
34
  extension_filter: str = "",
35
35
  ) -> List[Path]:
36
- """Create a recursive list of files in `data_path`.
36
+ """List recursively files in `data_path` and return a sorted list.
37
37
 
38
38
  If `data_path` is a file, its name is validated against the `data_type` using
39
39
  `fnmatch`, and the method returns `data_path` itself.
@@ -75,7 +75,7 @@ def list_files(
75
75
  raise FileNotFoundError(f"Data path {data_path} does not exist.")
76
76
 
77
77
  # get extension compatible with fnmatch and rglob search
78
- extension = SupportedData.get_extension(data_type)
78
+ extension = SupportedData.get_extension_pattern(data_type)
79
79
 
80
80
  if data_type == SupportedData.CUSTOM and extension_filter != "":
81
81
  extension = extension_filter
@@ -0,0 +1,83 @@
1
+ """Function to iterate over files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Callable, Generator, Optional, Union
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import get_worker_info
10
+
11
+ from careamics.config import DataConfig, InferenceConfig
12
+ from careamics.file_io.read import read_tiff
13
+ from careamics.utils.logging import get_logger
14
+
15
+ from .dataset_utils import reshape_array
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: list[Path],
    target_files: Optional[list[Path]] = None,
    read_source_func: Callable = read_tiff,
) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
    """Iterate over data source and yield whole reshaped images.

    When called inside a PyTorch DataLoader worker, files are split between
    workers in a round-robin fashion: a worker only reads the files whose index
    `i` satisfies `i % num_workers == worker_id`. Outside of a worker, every
    file is read.

    Files that fail to be read or reshaped are logged and skipped.

    Parameters
    ----------
    data_config : CAREamics DataConfig or InferenceConfig
        Configuration; its `axes` attribute is passed to `read_source_func` and
        to `reshape_array`.
    data_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path, optional
        List of target files, paired with `data_files` by position, by default
        None.
    read_source_func : Callable, optional
        Function to read the source, called as `read_source_func(path, axes)`,
        by default read_tiff.

    Yields
    ------
    tuple of (numpy.ndarray, numpy.ndarray or None)
        Reshaped image and, if `target_files` is provided, the corresponding
        reshaped target, otherwise None.

    Raises
    ------
    ValueError
        If `target_files` does not have the same length as `data_files`, or if a
        data file and its paired target file have different names.
    """
    # Pairing is positional, so differing lengths can only be a misconfiguration.
    if target_files is not None and len(target_files) != len(data_files):
        raise ValueError(
            f"Number of data files ({len(data_files)}) and target files "
            f"({len(target_files)}) must be equal."
        )

    # When num_workers > 0, each worker process will have a different copy of the
    # dataset object
    # Configuring each copy independently to avoid having duplicate data returned
    # from the workers
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    for i, filename in enumerate(data_files):
        # only handle the files assigned to this worker
        if i % num_workers != worker_id:
            continue

        target_filename = None
        if target_files is not None:
            target_filename = target_files[i]
            # Checked outside of the `try` below so that mismatched pairs fail
            # loudly instead of being logged and silently skipped.
            if filename.name != target_filename.name:
                raise ValueError(
                    f"File {filename} does not match target file "
                    f"{target_filename}. Have you passed sorted "
                    f"arrays?"
                )

        try:
            # read and reshape data
            sample = read_source_func(filename, data_config.axes)
            reshaped_sample = reshape_array(sample, data_config.axes)

            # read and reshape target, if available
            reshaped_target = None
            if target_filename is not None:
                target = read_source_func(target_filename, data_config.axes)
                reshaped_target = reshape_array(target, data_config.axes)
        except Exception as e:
            # Best effort: a single unreadable file does not stop the iteration.
            logger.error(f"Error reading file {filename}: {e}")
            continue

        yield reshaped_sample, reshaped_target
@@ -0,0 +1,186 @@
1
+ """Computing data statistics."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute the per-channel mean and standard deviation of an array.

    The input is expected to have shape (S, C, (Z), Y, X), i.e. channels along
    axis 1. Both statistics are reduced over every other axis; the standard
    deviation uses numpy's default `ddof=0`.

    Parameters
    ----------
    image : NDArray
        Input array with channels along axis 1.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        Per-channel means and standard deviations, each with one value per
        channel.
    """
    # Every axis index except 1, the channel axis.
    reduction_axes = tuple(np.delete(np.arange(image.ndim), 1))
    channel_means = np.mean(image, axis=reduction_axes)
    channel_stds = np.std(image, axis=reduction_axes)
    return channel_means, channel_stds
27
+
28
+
29
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update per-channel running mean and sum of squared deviations.

    Batched form of Welford's online algorithm: all the elements of each channel
    of `new_values` are folded into the running statistics at once. The input
    arrays are not modified; updated arrays are returned.

    Parameters
    ----------
    count : NDArray
        Number of elements seen so far, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Running sum of squared differences from the mean, per channel (divide by
        `count` to obtain the population variance).
    new_values : NDArray
        New values, iterated along the first axis with one entry per channel.
        Each entry may have any shape; it is flattened.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and m2.
    """
    # Flatten each channel once and work with array arithmetic against the
    # scalar per-channel mean, rather than materializing per-element lists.
    channels = [np.ravel(values) for values in new_values]

    new_count = count + np.array([values.size for values in channels])

    # new values - old mean
    deltas = [values - m for values, m in zip(channels, mean)]
    new_mean = mean + np.array([np.sum(d / c) for d, c in zip(deltas, new_count)])

    # new values - new mean
    deltas2 = [values - m for values, m in zip(channels, new_mean)]
    new_m2 = m2 + np.array([np.sum(d * d2) for d, d2 in zip(deltas, deltas2)])

    return new_count, new_mean, new_m2
67
+
68
+
69
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Finalize the mean and standard deviation computation.

    If any channel has fewer than two elements, the statistics are treated as
    undefined and arrays of NaN are returned for every channel.

    Parameters
    ----------
    count : NDArray
        Number of elements seen, per channel.
    mean : NDArray
        Running mean, per channel.
    m2 : NDArray
        Running sum of squared differences from the mean, per channel.

    Returns
    -------
    tuple[NDArray, NDArray]
        Final mean and population standard deviation (`sqrt(m2 / count)`), per
        channel.
    """
    # Check before dividing so that zero counts do not emit RuntimeWarnings for
    # values that would be discarded anyway.
    if np.any(np.asarray(count) < 2):
        return np.full(np.shape(mean), np.nan), np.full(np.shape(m2), np.nan)

    std = np.sqrt(np.asarray(m2, dtype=float) / np.asarray(count, dtype=float))
    return mean, std
93
+
94
+
95
class WelfordStatistics:
    """Compute per-channel mean and standard deviation iteratively.

    The Welford algorithm is used to compute the mean and variance of a stream
    of arrays without holding them all in memory. Each call to `update` folds
    all the elements of one array into the running statistics, channel by
    channel. Based on the implementation from:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    Arrays are expected to have their channel axis at position 1, e.g.
    (S, C, (Z), Y, X).
    """

    def __init__(self) -> None:
        """Initialize an empty accumulator."""
        # Last `sample_idx` passed to `update`, plus one.
        self.sample_idx = 0
        # Per-channel running statistics; set by the first call to `update`.
        self.count = None
        self.mean = None
        self.m2 = None

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Update the Welford statistics with a new array.

        Parameters
        ----------
        array : NDArray
            Input array, with channels along axis 1.
        sample_idx : int
            Current sample number. Passing 0 discards any previous statistics
            and starts accumulating from `array`.
        """
        # Channels-first view, (C, S, (Z), Y, X); no data is copied and each
        # channel holds the same elements, in the same order, as `array[:, c]`.
        channels = np.moveaxis(array, 1, 0)

        if sample_idx == 0 or self.count is None:
            # Seed the running mean with this array's own mean. The resulting
            # mean is the same, but m2 is then accumulated from deviations
            # around the actual mean instead of around zero, avoiding the
            # cancellation of sum(x**2) - n * mean**2.
            seed_mean, _ = compute_normalization_stats(array)
            count = np.zeros(array.shape[1])
            mean = seed_mean
            m2 = np.zeros(array.shape[1])
        else:
            count, mean, m2 = self.count, self.mean, self.m2

        self.count, self.mean, self.m2 = update_iterative_stats(
            count=count, mean=mean, m2=m2, new_values=channels
        )
        self.sample_idx = sample_idx + 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Finalize the Welford statistics.

        Returns
        -------
        tuple of numpy arrays
            Final per-channel mean and standard deviation.

        Raises
        ------
        ValueError
            If `update` was never called.
        """
        if self.count is None:
            raise ValueError(
                "No statistics to finalize, `update` must be called at least once."
            )
        return finalize_iterative_stats(self.count, self.mean, self.m2)
144
+
145
+
146
+ # from multiprocessing import Value
147
+ # from typing import tuple
148
+
149
+ # import numpy as np
150
+
151
+
152
+ # class RunningStats:
153
+ # """Calculates running mean and std."""
154
+
155
+ # def __init__(self) -> None:
156
+ # self.reset()
157
+
158
+ # def reset(self) -> None:
159
+ # """Reset the running stats."""
160
+ # self.avg_mean = Value("d", 0)
161
+ # self.avg_std = Value("d", 0)
162
+ # self.m2 = Value("d", 0)
163
+ # self.count = Value("i", 0)
164
+
165
+ # def init(self, mean: float, std: float) -> None:
166
+ # """Initialize running stats."""
167
+ # with self.avg_mean.get_lock():
168
+ # self.avg_mean.value += mean
169
+ # with self.avg_std.get_lock():
170
+ # self.avg_std.value = std
171
+
172
+ # def compute_std(self) -> tuple[float, float]:
173
+ # """Compute std."""
174
+ # if self.count.value >= 2:
175
+ # self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
176
+
177
+ # def update(self, value: float) -> None:
178
+ # """Update running stats."""
179
+ # with self.count.get_lock():
180
+ # self.count.value += 1
181
+ # delta = value - self.avg_mean.value
182
+ # with self.avg_mean.get_lock():
183
+ # self.avg_mean.value += delta / self.count.value
184
+ # delta2 = value - self.avg_mean.value
185
+ # with self.m2.get_lock():
186
+ # self.m2.value += delta * delta2