dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -1,345 +1,362 @@
1
- from abc import abstractmethod
2
- from enum import Flag
3
- from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Iterable
4
5
 
5
6
  import numpy as np
6
- from numpy.typing import ArrayLike
7
+ from numpy.typing import ArrayLike, NDArray
7
8
  from scipy.stats import entropy, kurtosis, skew
8
9
 
9
- from dataeval._internal.flags import ImageHash, ImageProperty, ImageStatistics, ImageStatsFlags, ImageVisuals
10
+ from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
10
11
  from dataeval._internal.interop import to_numpy_iter
11
- from dataeval._internal.metrics.base import EvaluateMixin
12
12
  from dataeval._internal.metrics.utils import edge_filter, get_bitdepth, normalize_image_shape, pchash, rescale, xxhash
13
+ from dataeval._internal.output import OutputMetadata, populate_defaults, set_metadata
13
14
 
14
- QUARTILES = (0, 25, 50, 75, 100)
15
+ CH_IDX_MAP = "ch_idx_map"
15
16
 
16
- TBatch = TypeVar("TBatch", bound=Sequence[ArrayLike])
17
- TFlag = TypeVar("TFlag", bound=Flag)
18
-
19
-
20
- class BaseStatsMetric(EvaluateMixin, Generic[TBatch, TFlag]):
21
- def __init__(self, flags: TFlag):
22
- self.flags = flags
23
- self.results = []
24
-
25
- @abstractmethod
26
- def update(self, images: TBatch) -> None:
27
- """
28
- Updates internal metric cache for later calculation
29
-
30
- Parameters
31
- ----------
32
- batch : Sequence
33
- Sequence of images to be processed
34
- """
35
-
36
- def compute(self) -> Dict[str, Any]:
37
- """
38
- Computes the specified measures on the cached values
39
-
40
- Returns
41
- -------
42
- Dict[str, Any]
43
- Dictionary results of the specified measures
44
- """
45
- return {stat: [result[stat] for result in self.results] for stat in self.results[0]}
46
-
47
- def reset(self) -> None:
48
- """
49
- Resets the internal metric cache
50
- """
51
- self.results = []
52
-
53
- def _map(self, func_map: Dict[Flag, Callable]) -> Dict[str, Any]:
54
- """Calculates the measures for each flag if it is selected."""
55
- results = {}
56
- for flag, func in func_map.items():
57
- if not flag.name:
58
- raise ValueError("Provided flag to set value does not have a name.")
59
- if flag & self.flags:
60
- results[flag.name.lower()] = func()
61
- return results
62
-
63
- def _keys(self) -> List[str]:
64
- """Returns the list of measures to be calculated."""
65
- flags = (
66
- self.flags
67
- if isinstance(self.flags, Iterable) # py3.11
68
- else [flag for flag in list(self.flags.__class__) if flag & self.flags]
69
- )
70
- return [flag.name.lower() for flag in flags if flag.name is not None]
71
-
72
- def evaluate(self, images: TBatch) -> Dict[str, Any]:
73
- """Calculate metric results given a single batch of images"""
74
- if self.results:
75
- raise RuntimeError("Call reset before calling evaluate")
76
-
77
- self.update(images)
78
- results = self.compute()
79
- self.reset()
80
- return results
81
-
82
-
83
- class ImageHashMetric(BaseStatsMetric):
84
- """
85
- Hashes images using the specified algorithms
86
17
 
87
- Parameters
18
+ @dataclass(frozen=True)
19
+ class StatsOutput(OutputMetadata):
20
+ """
21
+ Attributes
88
22
  ----------
89
- flags : ImageHash
90
- Algorithm(s) to calculate a hash as hex digest
23
+ xxhash : List[str]
24
+ xxHash hash of the images as a hex string
25
+ pchash : List[str]
26
+ Perception hash of the images as a hex string
27
+ width: NDArray[np.uint16]
28
+ Width of the images in pixels
29
+ height: NDArray[np.uint16]
30
+ Height of the images in pixels
31
+ channels: NDArray[np.uint8]
32
+ Channel count of the images in pixels
33
+ size: NDArray[np.uint32]
34
+ Size of the images in pixels
35
+ aspect_ratio: NDArray[np.float16]
36
+ Aspect ratio of the images (width/height)
37
+ depth: NDArray[np.uint8]
38
+ Color depth of the images in bits
39
+ brightness: NDArray[np.float16]
40
+ Brightness of the images
41
+ blurriness: NDArray[np.float16]
42
+ Blurriness of the images
43
+ missing: NDArray[np.float16]
44
+ Percentage of the images with missing pixels
45
+ zero: NDArray[np.float16]
46
+ Percentage of the images with zero value pixels
47
+ mean: NDArray[np.float16]
48
+ Mean of the pixel values of the images
49
+ std: NDArray[np.float16]
50
+ Standard deviation of the pixel values of the images
51
+ var: NDArray[np.float16]
52
+ Variance of the pixel values of the images
53
+ skew: NDArray[np.float16]
54
+ Skew of the pixel values of the images
55
+ kurtosis: NDArray[np.float16]
56
+ Kurtosis of the pixel values of the images
57
+ percentiles: NDArray[np.float16]
58
+ Percentiles of the pixel values of the images with quartiles of (0, 25, 50, 75, 100)
59
+ histogram: NDArray[np.uint32]
60
+ Histogram of the pixel values of the images across 256 bins scaled between 0 and 1
61
+ entropy: NDArray[np.float16]
62
+ Entropy of the pixel values of the images
63
+ ch_idx_map: Dict[int, List[int]]
64
+ Per-channel mapping of indices for each metric
91
65
  """
92
66
 
93
- def __init__(self, flags: ImageHash = ImageHash.ALL):
94
- super().__init__(flags)
95
-
96
- def update(self, images: Iterable[ArrayLike]) -> None:
97
- for image in to_numpy_iter(images):
98
- results = self._map(
99
- {
100
- ImageHash.XXHASH: lambda: xxhash(image),
101
- ImageHash.PCHASH: lambda: pchash(image),
102
- }
103
- )
104
- self.results.append(results)
67
+ xxhash: list[str]
68
+ pchash: list[str]
69
+ width: NDArray[np.uint16]
70
+ height: NDArray[np.uint16]
71
+ channels: NDArray[np.uint8]
72
+ size: NDArray[np.uint32]
73
+ aspect_ratio: NDArray[np.float16]
74
+ depth: NDArray[np.uint8]
75
+ brightness: NDArray[np.float16]
76
+ blurriness: NDArray[np.float16]
77
+ missing: NDArray[np.float16]
78
+ zero: NDArray[np.float16]
79
+ mean: NDArray[np.float16]
80
+ std: NDArray[np.float16]
81
+ var: NDArray[np.float16]
82
+ skew: NDArray[np.float16]
83
+ kurtosis: NDArray[np.float16]
84
+ percentiles: NDArray[np.float16]
85
+ histogram: NDArray[np.uint32]
86
+ entropy: NDArray[np.float16]
87
+ ch_idx_map: dict[int, list[int]]
88
+
89
+ def dict(self):
90
+ return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and len(v) > 0}
105
91
 
106
92
 
107
- class ImagePropertyMetric(BaseStatsMetric):
108
- """
109
- Calculates specified image properties
93
+ QUARTILES = (0, 25, 50, 75, 100)
110
94
 
111
- Parameters
112
- ----------
113
- flags: ImageProperty
114
- Property(ies) to calculate for each image
95
+ IMAGESTATS_FN_MAP: dict[ImageStat, Callable[[NDArray], Any]] = {
96
+ ImageStat.XXHASH: lambda x: xxhash(x),
97
+ ImageStat.PCHASH: lambda x: pchash(x),
98
+ ImageStat.WIDTH: lambda x: np.uint16(x.shape[-1]),
99
+ ImageStat.HEIGHT: lambda x: np.uint16(x.shape[-2]),
100
+ ImageStat.CHANNELS: lambda x: np.uint8(x.shape[-3]),
101
+ ImageStat.SIZE: lambda x: np.uint32(np.prod(x.shape[-2:])),
102
+ ImageStat.ASPECT_RATIO: lambda x: np.float16(x.shape[-1] / x.shape[-2]),
103
+ ImageStat.DEPTH: lambda x: np.uint8(get_bitdepth(x).depth),
104
+ ImageStat.BRIGHTNESS: lambda x: np.float16(np.mean(x)),
105
+ ImageStat.BLURRINESS: lambda x: np.float16(np.std(edge_filter(np.mean(x, axis=0)))),
106
+ ImageStat.MISSING: lambda x: np.float16(np.sum(np.isnan(x)) / np.prod(x.shape[-2:])),
107
+ ImageStat.ZERO: lambda x: np.float16(np.count_nonzero(x == 0) / np.prod(x.shape[-2:])),
108
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x)),
109
+ ImageStat.STD: lambda x: np.float16(np.std(x)),
110
+ ImageStat.VAR: lambda x: np.float16(np.var(x)),
111
+ ImageStat.SKEW: lambda x: np.float16(skew(x.ravel())),
112
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x.ravel())),
113
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES)),
114
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.histogram(x, 256, (0, 1))[0]),
115
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x)),
116
+ }
117
+
118
+ CHANNELSTATS_FN_MAP: dict[ImageStat, Callable[[NDArray], Any]] = {
119
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x, axis=1)),
120
+ ImageStat.STD: lambda x: np.float16(np.std(x, axis=1)),
121
+ ImageStat.VAR: lambda x: np.float16(np.var(x, axis=1)),
122
+ ImageStat.SKEW: lambda x: np.float16(skew(x, axis=1)),
123
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x, axis=1)),
124
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES, axis=1).T),
125
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.apply_along_axis(lambda y: np.histogram(y, 256, (0, 1))[0], 1, x)),
126
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x, axis=1)),
127
+ }
128
+
129
+
130
+ def run_stats(
131
+ images: Iterable[ArrayLike],
132
+ flags: ImageStat,
133
+ fn_map: dict[ImageStat, Callable[[NDArray], Any]],
134
+ flatten: bool,
135
+ ):
115
136
  """
137
+ Compute specified statistics on a set of images.
116
138
 
117
- def __init__(self, flags: ImageProperty = ImageProperty.ALL):
118
- super().__init__(flags)
119
-
120
- def update(self, images: Iterable[ArrayLike]) -> None:
121
- for image in to_numpy_iter(images):
122
- results = self._map(
123
- {
124
- ImageProperty.WIDTH: lambda: np.int32(image.shape[-1]),
125
- ImageProperty.HEIGHT: lambda: np.int32(image.shape[-2]),
126
- ImageProperty.SIZE: lambda: np.int32(image.shape[-1] * image.shape[-2]),
127
- ImageProperty.ASPECT_RATIO: lambda: image.shape[-1] / np.int32(image.shape[-2]),
128
- ImageProperty.CHANNELS: lambda: image.shape[-3],
129
- ImageProperty.DEPTH: lambda: get_bitdepth(image).depth,
130
- }
131
- )
132
- self.results.append(results)
133
-
134
-
135
- class ImageVisualsMetric(BaseStatsMetric):
136
- """
137
- Calculates specified visual image properties
139
+ This function applies a set of statistical operations to each image in the input iterable,
140
+ based on the specified flags. The function dynamically determines which statistics to apply
141
+ using a flag system and a corresponding function map. It also supports optional image
142
+ flattening for pixel-wise calculations.
138
143
 
139
144
  Parameters
140
145
  ----------
141
- flags: ImageVisuals
142
- Property(ies) to calculate for each image
146
+ images : ArrayLike
147
+ An iterable of images (e.g., list of arrays), where each image is represented as an
148
+ array-like structure (e.g., NumPy arrays).
149
+ flags : ImageStat
150
+ A bitwise flag or set of flags specifying the statistics to compute for each image.
151
+ These flags determine which functions in `fn_map` to apply.
152
+ fn_map : dict[ImageStat, Callable]
153
+ A dictionary mapping `ImageStat` flags to functions that compute the corresponding statistics.
154
+ Each function accepts a NumPy array (representing an image or rescaled pixel data) and returns a result.
155
+ flatten : bool
156
+ If True, the image is flattened into a 2D array for pixel-wise operations. Otherwise, the
157
+ original image dimensions are preserved.
158
+
159
+ Returns
160
+ -------
161
+ list[dict[str, NDArray]]
162
+ A list of dictionaries, where each dictionary contains the computed statistics for an image.
163
+ The dictionary keys correspond to the names of the statistics, and the values are NumPy arrays
164
+ with the results of the computations.
165
+
166
+ Raises
167
+ ------
168
+ ValueError
169
+ If unsupported flags are provided that are not present in `fn_map`.
170
+
171
+ Notes
172
+ -----
173
+ - The function performs image normalization (rescaling the image values)
174
+ before applying some of the statistics.
175
+ - Pixel-level statistics (e.g., brightness, entropy) are computed after
176
+ rescaling and, optionally, flattening the images.
177
+ - For statistics like histograms and entropy, intermediate results may
178
+ be reused to avoid redundant computation.
143
179
  """
144
-
145
- def __init__(self, flags: ImageVisuals = ImageVisuals.ALL):
146
- super().__init__(flags)
147
-
148
- def update(self, images: Iterable[ArrayLike]) -> None:
149
- for image in to_numpy_iter(images):
150
- results = self._map(
151
- {
152
- ImageVisuals.BRIGHTNESS: lambda: np.mean(rescale(image)),
153
- ImageVisuals.BLURRINESS: lambda: np.std(edge_filter(np.mean(image, axis=0))),
154
- ImageVisuals.MISSING: lambda: np.sum(np.isnan(image)),
155
- ImageVisuals.ZERO: lambda: np.int32(np.count_nonzero(image == 0)),
156
- }
157
- )
158
- self.results.append(results)
159
-
160
-
161
- class ImageStatisticsMetric(BaseStatsMetric):
180
+ verify_supported(flags, fn_map)
181
+ flag_dict = to_distinct(flags)
182
+
183
+ results_list: list[dict[str, NDArray]] = []
184
+ for image in to_numpy_iter(images):
185
+ normalized = normalize_image_shape(image)
186
+ scaled = None
187
+ hist = None
188
+ output: dict[str, NDArray] = {}
189
+ for flag, stat in flag_dict.items():
190
+ if flag & (ImageStat.ALL_PIXELSTATS | ImageStat.BRIGHTNESS):
191
+ if scaled is None:
192
+ scaled = rescale(normalized).reshape(image.shape[0], -1) if flatten else rescale(normalized)
193
+ if flag & (ImageStat.HISTOGRAM | ImageStat.ENTROPY):
194
+ if hist is None:
195
+ hist = fn_map[ImageStat.HISTOGRAM](scaled)
196
+ output[stat] = hist if flag & ImageStat.HISTOGRAM else fn_map[flag](hist)
197
+ else:
198
+ output[stat] = fn_map[flag](scaled)
199
+ else:
200
+ output[stat] = fn_map[flag](normalized)
201
+ results_list.append(output)
202
+ return results_list
203
+
204
+
205
+ @set_metadata("dataeval.metrics")
206
+ def imagestats(images: Iterable[ArrayLike], flags: ImageStat = ImageStat.ALL_STATS) -> StatsOutput:
162
207
  """
163
- Calculates descriptive statistics for each image
208
+ Calculates image and pixel statistics for each image
164
209
 
165
- Parameters
166
- ----------
167
- flags: ImageStatistics
168
- Statistic(s) to calculate for each image
169
- """
170
-
171
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
172
- super().__init__(flags)
173
-
174
- def update(self, images: Iterable[ArrayLike]) -> None:
175
- for image in to_numpy_iter(images):
176
- scaled = rescale(image)
177
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
178
- hist = np.histogram(scaled, bins=256, range=(0, 1))[0]
179
-
180
- results = self._map(
181
- {
182
- ImageStatistics.MEAN: lambda: np.mean(scaled),
183
- ImageStatistics.STD: lambda: np.std(scaled),
184
- ImageStatistics.VAR: lambda: np.var(scaled),
185
- ImageStatistics.SKEW: lambda: np.float32(skew(scaled.ravel())),
186
- ImageStatistics.KURTOSIS: lambda: np.float32(kurtosis(scaled.ravel())),
187
- ImageStatistics.PERCENTILES: lambda: np.percentile(scaled, q=QUARTILES),
188
- ImageStatistics.HISTOGRAM: lambda: hist,
189
- ImageStatistics.ENTROPY: lambda: np.float32(entropy(hist)),
190
- }
191
- )
192
- self.results.append(results)
193
-
194
-
195
- class ChannelStatisticsMetric(BaseStatsMetric):
196
- """
197
- Calculates descriptive statistics for each image per channel
210
+ This function computes various statistical metrics (e.g., mean, standard deviation, entropy)
211
+ on the images as a whole, based on the specified flags. It supports multiple types of statistics
212
+ that can be selected using the `flags` argument.
198
213
 
199
214
  Parameters
200
215
  ----------
201
- flags: ImageStatistics
202
- Statistic(s) to calculate for each image per channel
216
+ images : ArrayLike
217
+ Images to run statistical tests on
218
+ flags : ImageStat, default ImageStat.ALL_STATS
219
+ Metric(s) to calculate for each image. The default flag ``ImageStat.ALL_STATS``
220
+ computes all available statistics.
221
+
222
+ Returns
223
+ -------
224
+ StatsOutput
225
+ A dictionary-like object containing the computed statistics for each image. The keys correspond
226
+ to the names of the statistics (e.g., 'mean', 'std'), and the values are lists of results for
227
+ each image or numpy arrays when the results are multi-dimensional.
228
+
229
+ Notes
230
+ -----
231
+ - All metrics in the ImageStat.ALL_PIXELSTATS flag are scaled based on the perceived bit depth
232
+ (which is derived from the largest pixel value) to allow for better comparison
233
+ between images stored in different formats and different resolutions.
234
+ - ImageStat.ZERO and ImageStat.MISSING are presented as a percentage of total pixel counts
235
+
236
+ Examples
237
+ --------
238
+ Calculating the statistics on the images, whose shape is (C, H, W)
239
+
240
+ >>> results = imagestats(images, flags=ImageStat.MEAN | ImageStat.ALL_VISUALS)
241
+ >>> print(results.mean)
242
+ [0.16650391 0.52050781 0.05471802 0.07702637 0.09875488 0.12188721
243
+ 0.14440918 0.16711426 0.18859863 0.21264648 0.2355957 0.25854492
244
+ 0.27978516 0.3046875 0.32788086 0.35131836 0.37255859 0.39819336
245
+ 0.42163086 0.4453125 0.46630859 0.49267578 0.51660156 0.54052734
246
+ 0.56152344 0.58837891 0.61230469 0.63671875 0.65771484 0.68505859
247
+ 0.70947266 0.73388672 0.75488281 0.78271484 0.80712891 0.83203125
248
+ 0.85302734 0.88134766 0.90625 0.93115234]
249
+ >>> print(results.zero)
250
+ [0.12561035 0. 0. 0. 0.11730957 0.
251
+ 0. 0. 0.10986328 0. 0. 0.
252
+ 0.10266113 0. 0. 0. 0.09570312 0.
253
+ 0. 0. 0.08898926 0. 0. 0.
254
+ 0.08251953 0. 0. 0. 0.07629395 0.
255
+ 0. 0. 0.0703125 0. 0. 0.
256
+ 0.0645752 0. 0. 0. ]
203
257
  """
204
-
205
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
206
- super().__init__(flags)
207
-
208
- def update(self, images: Iterable[ArrayLike]) -> None:
209
- for image in to_numpy_iter(images):
210
- scaled = rescale(image)
211
- flattened = scaled.reshape(image.shape[0], -1)
212
-
213
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
214
- hist = np.apply_along_axis(lambda x: np.histogram(x, bins=256, range=(0, 1))[0], 1, flattened)
215
-
216
- results = self._map(
217
- {
218
- ImageStatistics.MEAN: lambda: np.mean(flattened, axis=1),
219
- ImageStatistics.STD: lambda: np.std(flattened, axis=1),
220
- ImageStatistics.VAR: lambda: np.var(flattened, axis=1),
221
- ImageStatistics.SKEW: lambda: skew(flattened, axis=1),
222
- ImageStatistics.KURTOSIS: lambda: kurtosis(flattened, axis=1),
223
- ImageStatistics.PERCENTILES: lambda: np.percentile(flattened, q=QUARTILES, axis=1).T,
224
- ImageStatistics.HISTOGRAM: lambda: hist,
225
- ImageStatistics.ENTROPY: lambda: entropy(hist, axis=1),
226
- }
227
- )
228
- self.results.append(results)
229
-
230
-
231
- class BaseAggregateMetric(BaseStatsMetric, Generic[TFlag]):
232
- FLAG_METRIC_MAP: Dict[type, type]
233
- DEFAULT_FLAGS: Sequence[TFlag]
234
-
235
- def __init__(self, flags: Optional[Union[TFlag, Sequence[TFlag]]] = None):
236
- flag_dict = {}
237
- for flag in flags if isinstance(flags, Sequence) else self.DEFAULT_FLAGS if not flags else [flags]:
238
- flag_dict[type(flag)] = flag_dict.setdefault(type(flag), type(flag)(0)) | flag
239
- self._metrics_dict = {
240
- metric: []
241
- for metric in (
242
- self.FLAG_METRIC_MAP[flag_class](flag) for flag_class, flag in flag_dict.items() if flag.value != 0
243
- )
244
- }
245
-
246
-
247
- class ImageStats(BaseAggregateMetric):
258
+ stats = run_stats(images, flags, IMAGESTATS_FN_MAP, False)
259
+ output = {}
260
+ length = len(stats)
261
+ for i, results in enumerate(stats):
262
+ for stat, result in results.items():
263
+ if not isinstance(result, (np.ndarray, np.generic)):
264
+ output.setdefault(stat, []).append(result)
265
+ else:
266
+ shape = () if np.isscalar(result) else result.shape
267
+ output.setdefault(stat, np.empty((length,) + shape))[i] = result
268
+ return StatsOutput(**populate_defaults(output, StatsOutput))
269
+
270
+
271
+ @set_metadata("dataeval.metrics")
272
+ def channelstats(images: Iterable[ArrayLike], flags=ImageStat.ALL_PIXELSTATS) -> StatsOutput:
248
273
  """
249
- Calculates various image property statistics
274
+ Calculates pixel statistics for each image per channel
275
+
276
+ This function computes pixel-level statistics (e.g., mean, variance, etc.) on a per-channel basis
277
+ for each image. The statistics can be selected using the `flags` argument, and the results will
278
+ be grouped by the number of channels (e.g., RGB channels) in each image.
250
279
 
251
280
  Parameters
252
281
  ----------
253
- flags: [ImageHash | ImageProperty | ImageStatistics | ImageVisuals], default None
254
- Metric(s) to calculate for each image per channel - calculates all metrics if None
282
+ images : ArrayLike
283
+ Images to run statistical tests on
284
+ flags: ImageStat, default ImageStat.ALL_PIXELSTATS
285
+ Metric(s) to calculate for each image per channel.
286
+ Only flags within the ``ImageStat.ALL_PIXELSTATS`` category are supported.
287
+
288
+ Returns
289
+ -------
290
+ StatsOutput
291
+ A dictionary-like object containing the computed statistics for each image per channel. The keys
292
+ correspond to the names of the statistics (e.g., 'mean', 'variance'), and the values are numpy arrays
293
+ with results for each channel of each image.
294
+
295
+ Notes
296
+ -----
297
+ - All metrics in the ImageStat.ALL_PIXELSTATS flag are scaled based on the perceived bit depth
298
+ (which is derived from the largest pixel value) to allow for better comparison
299
+ between images stored in different formats and different resolutions.
300
+
301
+ Examples
302
+ --------
303
+ Calculating the statistics on a per channel basis for images, whose shape is (N, C, H, W)
304
+
305
+ >>> results = channelstats(images, flags=ImageStat.MEAN | ImageStat.VAR)
306
+ >>> print(results.mean)
307
+ {3: array([[0.01617, 0.5303 , 0.06525, 0.09735, 0.1295 , 0.1616 , 0.1937 ,
308
+ 0.2258 , 0.2578 , 0.29 , 0.322 , 0.3542 , 0.3865 , 0.4185 ,
309
+ 0.4507 , 0.4827 , 0.5146 , 0.547 , 0.579 , 0.6113 , 0.643 ,
310
+ 0.6753 , 0.7075 , 0.7397 , 0.7715 , 0.8037 , 0.836 , 0.868 ,
311
+ 0.9004 , 0.932 ],
312
+ [0.04828, 0.562 , 0.06726, 0.09937, 0.1315 , 0.1636 , 0.1957 ,
313
+ 0.2278 , 0.26 , 0.292 , 0.3242 , 0.3562 , 0.3884 , 0.4204 ,
314
+ 0.4526 , 0.4846 , 0.5166 , 0.549 , 0.581 , 0.6133 , 0.6455 ,
315
+ 0.6772 , 0.7095 , 0.7417 , 0.774 , 0.8057 , 0.838 , 0.87 ,
316
+ 0.9023 , 0.934 ],
317
+ [0.0804 , 0.594 , 0.0693 , 0.1014 , 0.1334 , 0.1656 , 0.1978 ,
318
+ 0.2299 , 0.262 , 0.294 , 0.3262 , 0.3584 , 0.3904 , 0.4226 ,
319
+ 0.4546 , 0.4868 , 0.519 , 0.551 , 0.583 , 0.615 , 0.6475 ,
320
+ 0.679 , 0.7114 , 0.7437 , 0.776 , 0.808 , 0.84 , 0.872 ,
321
+ 0.9043 , 0.9365 ]], dtype=float16)}
322
+ >>> print(results.var)
323
+ {3: array([[0.00010103, 0.01077 , 0.0001621 , 0.0003605 , 0.0006375 ,
324
+ 0.000993 , 0.001427 , 0.001939 , 0.00253 , 0.003199 ,
325
+ 0.003944 , 0.004772 , 0.005676 , 0.006657 , 0.007717 ,
326
+ 0.00886 , 0.01008 , 0.01137 , 0.01275 , 0.0142 ,
327
+ 0.01573 , 0.01733 , 0.01903 , 0.0208 , 0.02264 ,
328
+ 0.02457 , 0.02657 , 0.02864 , 0.0308 , 0.03305 ],
329
+ [0.0001798 , 0.0121 , 0.0001721 , 0.0003753 , 0.0006566 ,
330
+ 0.001017 , 0.001455 , 0.001972 , 0.002565 , 0.003239 ,
331
+ 0.00399 , 0.00482 , 0.00573 , 0.006714 , 0.007782 ,
332
+ 0.00893 , 0.01015 , 0.011444 , 0.012825 , 0.01428 ,
333
+ 0.01581 , 0.01743 , 0.01912 , 0.02089 , 0.02274 ,
334
+ 0.02466 , 0.02667 , 0.02875 , 0.03091 , 0.03314 ],
335
+ [0.000337 , 0.0135 , 0.0001824 , 0.0003903 , 0.0006766 ,
336
+ 0.00104 , 0.001484 , 0.002005 , 0.002604 , 0.00328 ,
337
+ 0.004036 , 0.00487 , 0.005783 , 0.006775 , 0.00784 ,
338
+ 0.00899 , 0.010216 , 0.01152 , 0.0129 , 0.01436 ,
339
+ 0.0159 , 0.01752 , 0.01921 , 0.02098 , 0.02283 ,
340
+ 0.02477 , 0.02676 , 0.02885 , 0.03102 , 0.03326 ]],
341
+ dtype=float16)}
255
342
  """
256
-
257
- FLAG_METRIC_MAP = {
258
- ImageHash: ImageHashMetric,
259
- ImageProperty: ImagePropertyMetric,
260
- ImageStatistics: ImageStatisticsMetric,
261
- ImageVisuals: ImageVisualsMetric,
262
- }
263
- DEFAULT_FLAGS = [ImageHash.ALL, ImageProperty.ALL, ImageStatistics.ALL, ImageVisuals.ALL]
264
-
265
- def __init__(self, flags: Optional[Union[ImageStatsFlags, Sequence[ImageStatsFlags]]] = None):
266
- super().__init__(flags)
267
- self._length = 0
268
-
269
- def update(self, images: Iterable[ArrayLike]) -> None:
270
- for image in to_numpy_iter(images):
271
- self._length += 1
272
- img = normalize_image_shape(image)
273
- for metric in self._metrics_dict:
274
- metric.update([img])
275
-
276
- def compute(self) -> Dict[str, Any]:
277
- for metric in self._metrics_dict:
278
- self._metrics_dict[metric] = metric.results
279
-
280
- stats = {}
281
- for metric, results in self._metrics_dict.items():
282
- for i, result in enumerate(results):
283
- for stat in metric._keys():
284
- value = result[stat]
285
- if not isinstance(value, (np.ndarray, np.generic)):
286
- if stat not in stats:
287
- stats[stat] = []
288
- stats[stat].append(result[stat])
289
- else:
290
- if stat not in stats:
291
- shape = () if np.isscalar(result[stat]) else result[stat].shape
292
- stats[stat] = np.empty((self._length,) + shape)
293
- stats[stat][i] = result[stat]
294
- return stats
295
-
296
- def reset(self):
297
- self._length = 0
298
- for metric in self._metrics_dict:
299
- metric.reset()
300
- self._metrics_dict[metric] = []
301
-
302
-
303
- class ChannelStats(BaseAggregateMetric):
304
- FLAG_METRIC_MAP = {ImageStatistics: ChannelStatisticsMetric}
305
- DEFAULT_FLAGS = [ImageStatistics.ALL]
306
- IDX_MAP = "idx_map"
307
-
308
- def __init__(self, flags: Optional[ImageStatistics] = None) -> None:
309
- super().__init__(flags)
310
-
311
- def update(self, images: Iterable[ArrayLike]) -> None:
312
- for image in to_numpy_iter(images):
313
- img = normalize_image_shape(image)
314
- for metric in self._metrics_dict:
315
- metric.update([img])
316
-
317
- for metric in self._metrics_dict:
318
- self._metrics_dict[metric] = metric.results
319
-
320
- def compute(self) -> Dict[str, Any]:
321
- # Aggregate all metrics into a single dictionary
322
- stats = {}
323
- channel_stats = set()
324
- for metric, results in self._metrics_dict.items():
325
- for i, result in enumerate(results):
326
- for stat in metric._keys():
327
- channel_stats.update(metric._keys())
328
- channels = result[stat].shape[0]
329
- stats.setdefault(self.IDX_MAP, {}).setdefault(channels, {})[i] = None
330
- stats.setdefault(stat, {}).setdefault(channels, []).append(result[stat])
331
-
332
- # Concatenate list of channel statistics numpy
333
- for stat in channel_stats:
334
- for channel in stats[stat]:
335
- stats[stat][channel] = np.array(stats[stat][channel]).T
336
-
337
- for channel in stats[self.IDX_MAP]:
338
- stats[self.IDX_MAP][channel] = list(stats[self.IDX_MAP][channel].keys())
339
-
340
- return stats
341
-
342
- def reset(self) -> None:
343
- for metric in self._metrics_dict:
344
- metric.reset()
345
- self._metrics_dict[metric] = []
343
+ stats = run_stats(images, flags, CHANNELSTATS_FN_MAP, True)
344
+
345
+ output = {}
346
+ for i, results in enumerate(stats):
347
+ for stat, result in results.items():
348
+ channels = result.shape[0]
349
+ output.setdefault(stat, {}).setdefault(channels, []).append(result)
350
+ output.setdefault(CH_IDX_MAP, {}).setdefault(channels, {})[i] = None
351
+
352
+ # Concatenate list of channel statistics numpy
353
+ for stat in output:
354
+ if stat == CH_IDX_MAP:
355
+ continue
356
+ for channel in output[stat]:
357
+ output[stat][channel] = np.array(output[stat][channel]).T
358
+
359
+ for channel in output[CH_IDX_MAP]:
360
+ output[CH_IDX_MAP][channel] = list(output[CH_IDX_MAP][channel].keys())
361
+
362
+ return StatsOutput(**populate_defaults(output, StatsOutput))