careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/dataset/dataset_utils/dataset_utils.py
@@ -0,0 +1,101 @@
+ """Dataset utilities."""
+
+ from typing import List, Tuple
+
+ import numpy as np
+
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def _get_shape_order(
+     shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
+ ) -> Tuple[Tuple[int, ...], str, List[int]]:
+     """
+     Compute a new shape for the array based on the reference axes.
+
+     Parameters
+     ----------
+     shape_in : Tuple[int, ...]
+         Input shape.
+     axes_in : str
+         Input axes.
+     ref_axes : str
+         Reference axes.
+
+     Returns
+     -------
+     Tuple[Tuple[int, ...], str, List[int]]
+         New shape, new axes, indices of axes in the new axes order.
+     """
+     indices = [axes_in.find(k) for k in ref_axes]
+
+     # remove all non-existing axes (index == -1)
+     new_indices = list(filter(lambda k: k != -1, indices))
+
+     # find axes order and get new shape
+     new_axes = [axes_in[ind] for ind in new_indices]
+     new_shape = tuple([shape_in[ind] for ind in new_indices])
+
+     return new_shape, "".join(new_axes), new_indices
+
+
+ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
+     """Reshape the data to (S, C, (Z), Y, X) by moving axes.
+
+     If the data has both S and T axes, the two axes will be merged. A singleton
+     dimension is added if there is no C axis.
+
+     Parameters
+     ----------
+     x : np.ndarray
+         Input array.
+     axes : str
+         Description of axes in format `STCZYX`.
+
+     Returns
+     -------
+     np.ndarray
+         Reshaped array with shape (S, C, (Z), Y, X).
+     """
+     _x = x
+     _axes = axes
+
+     # sanity checks
+     if len(_axes) != len(_x.shape):
+         raise ValueError(
+             f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes "
+             f"correct?"
+         )
+
+     # get new x shape
+     new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes)
+
+     # if S is not in the list of axes, then add a singleton S
+     if "S" not in new_axes:
+         new_axes = "S" + new_axes
+         _x = _x[np.newaxis, ...]
+         new_x_shape = (1,) + new_x_shape
+
+         # need to change the array of indices
+         indices = [0] + [1 + i for i in indices]
+
+     # reshape by moving axes
+     destination = list(range(len(indices)))
+     _x = np.moveaxis(_x, indices, destination)
+
+     # remove T if necessary
+     if "T" in new_axes:
+         new_x_shape = (-1,) + new_x_shape[2:]  # merge T and S
+         new_axes = new_axes.replace("T", "")
+
+     # reshape S and T together
+     _x = _x.reshape(new_x_shape)
+
+     # add channel
+     if "C" not in new_axes:
+         # Add channel axis after S
+         _x = np.expand_dims(_x, new_axes.index("S") + 1)
+
+     return _x
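
Usage sketch (not part of the diff) for `reshape_array`, assuming it is re-exported from the `careamics.dataset.dataset_utils` package as the relative import in `iterate_over_files.py` below suggests: a TYX stack gains singleton S and C axes, and S and T are merged when both are present.

import numpy as np

from careamics.dataset.dataset_utils import reshape_array

x = np.zeros((5, 64, 64))  # axes "TYX"
print(reshape_array(x, "TYX").shape)  # (5, 1, 64, 64): T folded into S, C added

x = np.zeros((2, 3, 8, 64, 64))  # axes "STZYX"
print(reshape_array(x, "STZYX").shape)  # (6, 1, 8, 64, 64): S and T merged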
careamics/dataset/dataset_utils/file_utils.py
@@ -0,0 +1,141 @@
+ """File utilities."""
+
+ from fnmatch import fnmatch
+ from pathlib import Path
+ from typing import List, Union
+
+ import numpy as np
+
+ from careamics.config.support import SupportedData
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def get_files_size(files: List[Path]) -> float:
+     """Get the total size of files in MB.
+
+     Parameters
+     ----------
+     files : List[Path]
+         List of files.
+
+     Returns
+     -------
+     float
+         Total size of the files in MB.
+     """
+     return np.sum([f.stat().st_size / 1024**2 for f in files])
+
+
+ def list_files(
+     data_path: Union[str, Path],
+     data_type: Union[str, SupportedData],
+     extension_filter: str = "",
+ ) -> List[Path]:
+     """Recursively list files in `data_path` and return a sorted list.
+
+     If `data_path` is a file, its name is validated against the `data_type` using
+     `fnmatch`, and the method returns `data_path` itself.
+
+     By default, if `data_type` is equal to `custom`, all files will be listed. To
+     further filter the files, use `extension_filter`.
+
+     `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
+     or "*.czi".
+
+     Parameters
+     ----------
+     data_path : Union[str, Path]
+         Path to the folder containing the data.
+     data_type : Union[str, SupportedData]
+         One of the supported data types (e.g. tif, custom).
+     extension_filter : str, optional
+         Extension filter, by default "".
+
+     Returns
+     -------
+     List[Path]
+         List of pathlib.Path objects.
+
+     Raises
+     ------
+     FileNotFoundError
+         If the data path does not exist.
+     ValueError
+         If the data path is empty or no files with the extension were found.
+     ValueError
+         If the file does not match the requested extension.
+     """
+     # convert to Path
+     data_path = Path(data_path)
+
+     # raise an error if the path does not exist
+     if not data_path.exists():
+         raise FileNotFoundError(f"Data path {data_path} does not exist.")
+
+     # get extension compatible with fnmatch and rglob search
+     extension = SupportedData.get_extension_pattern(data_type)
+
+     if data_type == SupportedData.CUSTOM and extension_filter != "":
+         extension = extension_filter
+
+     # search recursively
+     if data_path.is_dir():
+         # search the path recursively for files with the extension
+         files = sorted(data_path.rglob(extension))
+     else:
+         # raise an error if the file has the wrong extension
+         if not fnmatch(str(data_path.absolute()), extension):
+             raise ValueError(
+                 f"File {data_path} does not match the requested extension "
+                 f'"{extension}".'
+             )
+
+         # save in list
+         files = [data_path]
+
+     # raise an error if no files were found
+     if len(files) == 0:
+         raise ValueError(
+             f'Data path {data_path} is empty or files with extension "{extension}" '
+             f"were not found."
+         )
+
+     return files
+
+
+ def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
+     """
+     Validate source and target path lists.
+
+     The two lists should have the same number of files, and the filenames should match.
+
+     Parameters
+     ----------
+     src_files : List[Path]
+         List of source files.
+     tar_files : List[Path]
+         List of target files.
+
+     Raises
+     ------
+     ValueError
+         If the number of files in the source and target folders is not the same.
+     ValueError
+         If some filenames in the source and target folders are not the same.
+     """
+     # check equal length
+     if len(src_files) != len(tar_files):
+         raise ValueError(
+             f"The number of source files ({len(src_files)}) is not equal to the number "
+             f"of target files ({len(tar_files)})."
+         )
+
+     # check identical names
+     src_names = {f.name for f in src_files}
+     tar_names = {f.name for f in tar_files}
+     difference = src_names.symmetric_difference(tar_names)
+
+     if len(difference) > 0:
+         raise ValueError(f"Source and target files have different names: {difference}.")
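
Usage sketch (not part of the diff) for the two helpers above, using the `custom` data type with an `fnmatch`-style filter as described in the docstring; the folder paths are hypothetical, and string values are assumed to compare equal to the `SupportedData` enum members as elsewhere in the package.

from careamics.dataset.dataset_utils.file_utils import (
    list_files,
    validate_source_target_files,
)

# recursively list .tif files under hypothetical source/target folders
src = list_files("data/train", "custom", extension_filter="*.tif")
tar = list_files("data/target", "custom", extension_filter="*.tif")

# raises ValueError if the counts differ or the filenames do not pair up
validate_source_target_files(src, tar)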
careamics/dataset/dataset_utils/iterate_over_files.py
@@ -0,0 +1,83 @@
+ """Function to iterate over files."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Callable, Generator, Optional, Union
+
+ from numpy.typing import NDArray
+ from torch.utils.data import get_worker_info
+
+ from careamics.config import DataConfig, InferenceConfig
+ from careamics.file_io.read import read_tiff
+ from careamics.utils.logging import get_logger
+
+ from .dataset_utils import reshape_array
+
+ logger = get_logger(__name__)
+
+
+ def iterate_over_files(
+     data_config: Union[DataConfig, InferenceConfig],
+     data_files: list[Path],
+     target_files: Optional[list[Path]] = None,
+     read_source_func: Callable = read_tiff,
+ ) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+     """Iterate over data source and yield whole reshaped images.
+
+     Parameters
+     ----------
+     data_config : CAREamics DataConfig or InferenceConfig
+         Configuration.
+     data_files : list of pathlib.Path
+         List of data files.
+     target_files : list of pathlib.Path, optional
+         List of target files, by default None.
+     read_source_func : Callable, optional
+         Function to read the source, by default read_tiff.
+
+     Yields
+     ------
+     tuple of NDArray and optional NDArray
+         Image and, if available, the corresponding target.
+     """
+     # When num_workers > 0, each worker process has its own copy of the dataset
+     # object; each copy is configured independently so that workers do not
+     # return duplicate data.
+     worker_info = get_worker_info()
+     worker_id = worker_info.id if worker_info is not None else 0
+     num_workers = worker_info.num_workers if worker_info is not None else 1
+
+     # iterate over the files
+     for i, filename in enumerate(data_files):
+         # retrieve the files corresponding to the worker id
+         if i % num_workers == worker_id:
+             try:
+                 # read data
+                 sample = read_source_func(filename, data_config.axes)
+
+                 # reshape array
+                 reshaped_sample = reshape_array(sample, data_config.axes)
+
+                 # read target, if available
+                 if target_files is not None:
+                     if filename.name != target_files[i].name:
+                         raise ValueError(
+                             f"File {filename} does not match target file "
+                             f"{target_files[i]}. Have you passed sorted "
+                             f"arrays?"
+                         )
+
+                     # read target
+                     target = read_source_func(target_files[i], data_config.axes)
+
+                     # reshape target
+                     reshaped_target = reshape_array(target, data_config.axes)
+
+                     yield reshaped_sample, reshaped_target
+                 else:
+                     yield reshaped_sample, None
+
+             except Exception as e:
+                 logger.error(f"Error reading file {filename}: {e}")
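
The `i % num_workers == worker_id` test above shards files round-robin across DataLoader workers, so each file is read by exactly one worker. A standalone sketch of the sharding rule (hypothetical file names, three workers):

files = [f"img_{i}.tif" for i in range(7)]
num_workers = 3

for worker_id in range(num_workers):
    # each worker keeps only the files whose index maps to its id
    shard = [f for i, f in enumerate(files) if i % num_workers == worker_id]
    print(worker_id, shard)

# 0 ['img_0.tif', 'img_3.tif', 'img_6.tif']
# 1 ['img_1.tif', 'img_4.tif']
# 2 ['img_2.tif', 'img_5.tif']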
careamics/dataset/dataset_utils/running_stats.py
@@ -0,0 +1,186 @@
+ """Computing data statistics."""
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+
+ def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
+     """
+     Compute mean and standard deviation of an array.
+
+     Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
+     computed per channel.
+
+     Parameters
+     ----------
+     image : NDArray
+         Input array.
+
+     Returns
+     -------
+     tuple of (NDArray, NDArray)
+         Arrays of mean and standard deviation values per channel.
+     """
+     # Define the list of axes excluding the channel axis
+     axes = tuple(np.delete(np.arange(image.ndim), 1))
+     return np.mean(image, axis=axes), np.std(image, axis=axes)
+
+
+ def update_iterative_stats(
+     count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
+ ) -> tuple[NDArray, NDArray, NDArray]:
+     """Update the mean and variance of an array iteratively.
+
+     Parameters
+     ----------
+     count : NDArray
+         Number of elements per channel.
+     mean : NDArray
+         Mean per channel.
+     m2 : NDArray
+         Sum of squared deviations from the mean (M2) per channel.
+     new_values : NDArray
+         New values to add to the mean and variance.
+
+     Returns
+     -------
+     tuple[NDArray, NDArray, NDArray]
+         Updated count, mean, and M2.
+     """
+     count += np.array([np.prod(channel.shape) for channel in new_values])
+     # new values - old mean
+     delta = [
+         np.subtract(v.flatten(), [m] * len(v.flatten()))
+         for v, m in zip(new_values, mean)
+     ]
+
+     mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
+     # new values - new mean
+     delta2 = [
+         np.subtract(v.flatten(), [m] * len(v.flatten()))
+         for v, m in zip(new_values, mean)
+     ]
+
+     m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
+
+     return (count, mean, m2)
+
+
+ def finalize_iterative_stats(
+     count: NDArray, mean: NDArray, m2: NDArray
+ ) -> tuple[NDArray, NDArray]:
+     """Finalize the mean and variance computation.
+
+     Parameters
+     ----------
+     count : NDArray
+         Number of elements per channel.
+     mean : NDArray
+         Mean per channel.
+     m2 : NDArray
+         Sum of squared deviations from the mean (M2) per channel.
+
+     Returns
+     -------
+     tuple[NDArray, NDArray]
+         Final mean and standard deviation.
+     """
+     std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
+     if any(c < 2 for c in count):
+         return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
+     else:
+         return mean, std
+
+
+ class WelfordStatistics:
+     """Compute Welford statistics iteratively.
+
+     The Welford algorithm is used to compute the mean and variance of an array
+     iteratively. Based on the implementation from:
+     https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+     """
+
+     def update(self, array: NDArray, sample_idx: int) -> None:
+         """Update the Welford statistics.
+
+         Parameters
+         ----------
+         array : NDArray
+             Input array.
+         sample_idx : int
+             Current sample number.
+         """
+         self.sample_idx = sample_idx
+         sample_channels = np.array(np.split(array, array.shape[1], axis=1))
+
+         # Initialize the statistics
+         if self.sample_idx == 0:
+             # Compute the mean and standard deviation
+             self.mean, _ = compute_normalization_stats(array)
+             # Initialize the count and m2 with zero-valued arrays of shape (C,)
+             self.count, self.mean, self.m2 = update_iterative_stats(
+                 count=np.zeros(array.shape[1]),
+                 mean=self.mean,
+                 m2=np.zeros(array.shape[1]),
+                 new_values=sample_channels,
+             )
+         else:
+             # Update the statistics
+             self.count, self.mean, self.m2 = update_iterative_stats(
+                 count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
+             )
+
+         self.sample_idx += 1
+
+     def finalize(self) -> tuple[NDArray, NDArray]:
+         """Finalize the Welford statistics.
+
+         Returns
+         -------
+         tuple of NDArray
+             Final mean and standard deviation.
+         """
+         return finalize_iterative_stats(self.count, self.mean, self.m2)
+
+
+ # from multiprocessing import Value
+ # from typing import tuple
+
+ # import numpy as np
+
+
+ # class RunningStats:
+ #     """Calculates running mean and std."""
+
+ #     def __init__(self) -> None:
+ #         self.reset()
+
+ #     def reset(self) -> None:
+ #         """Reset the running stats."""
+ #         self.avg_mean = Value("d", 0)
+ #         self.avg_std = Value("d", 0)
+ #         self.m2 = Value("d", 0)
+ #         self.count = Value("i", 0)
+
+ #     def init(self, mean: float, std: float) -> None:
+ #         """Initialize running stats."""
+ #         with self.avg_mean.get_lock():
+ #             self.avg_mean.value += mean
+ #         with self.avg_std.get_lock():
+ #             self.avg_std.value = std
+
+ #     def compute_std(self) -> tuple[float, float]:
+ #         """Compute std."""
+ #         if self.count.value >= 2:
+ #             self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
+
+ #     def update(self, value: float) -> None:
+ #         """Update running stats."""
+ #         with self.count.get_lock():
+ #             self.count.value += 1
+ #         delta = value - self.avg_mean.value
+ #         with self.avg_mean.get_lock():
+ #             self.avg_mean.value += delta / self.count.value
+ #         delta2 = value - self.avg_mean.value
+ #         with self.m2.get_lock():
+ #             self.m2.value += delta * delta2
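
Sanity-check sketch (not part of the diff): feeding (S, C, Y, X) samples to `WelfordStatistics` and comparing the result with the one-shot per-channel statistics. The two should agree, since `finalize_iterative_stats` divides M2 by the element count, matching `np.std` with its default `ddof=0`.

import numpy as np

from careamics.dataset.dataset_utils.running_stats import (
    WelfordStatistics,
    compute_normalization_stats,
)

rng = np.random.default_rng(seed=0)
samples = [rng.normal(size=(1, 2, 32, 32)) for _ in range(4)]  # (S, C, Y, X)

# accumulate the statistics sample by sample
stats = WelfordStatistics()
for idx, sample in enumerate(samples):
    stats.update(sample, sample_idx=idx)
mean, std = stats.finalize()

# one-shot reference over the concatenated samples
ref_mean, ref_std = compute_normalization_stats(np.concatenate(samples, axis=0))
print(np.allclose(mean, ref_mean), np.allclose(std, ref_std))  # True True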