dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. dataeval/__init__.py +4 -4
  2. dataeval/_internal/detectors/clusterer.py +47 -34
  3. dataeval/_internal/detectors/drift/base.py +53 -35
  4. dataeval/_internal/detectors/drift/cvm.py +5 -4
  5. dataeval/_internal/detectors/drift/ks.py +7 -6
  6. dataeval/_internal/detectors/drift/mmd.py +39 -19
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -8
  9. dataeval/_internal/detectors/duplicates.py +57 -30
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/ae.py +2 -1
  12. dataeval/_internal/detectors/ood/aegmm.py +2 -1
  13. dataeval/_internal/detectors/ood/base.py +37 -15
  14. dataeval/_internal/detectors/ood/llr.py +9 -8
  15. dataeval/_internal/detectors/ood/vae.py +2 -1
  16. dataeval/_internal/detectors/ood/vaegmm.py +2 -1
  17. dataeval/_internal/flags.py +42 -21
  18. dataeval/_internal/interop.py +3 -12
  19. dataeval/_internal/metrics/balance.py +188 -0
  20. dataeval/_internal/metrics/ber.py +123 -48
  21. dataeval/_internal/metrics/coverage.py +90 -74
  22. dataeval/_internal/metrics/divergence.py +101 -67
  23. dataeval/_internal/metrics/diversity.py +211 -0
  24. dataeval/_internal/metrics/parity.py +287 -155
  25. dataeval/_internal/metrics/stats.py +198 -317
  26. dataeval/_internal/metrics/uap.py +40 -29
  27. dataeval/_internal/metrics/utils.py +430 -0
  28. dataeval/_internal/models/tensorflow/losses.py +3 -3
  29. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  30. dataeval/_internal/models/tensorflow/utils.py +4 -3
  31. dataeval/_internal/output.py +82 -0
  32. dataeval/_internal/utils.py +64 -0
  33. dataeval/_internal/workflows/sufficiency.py +96 -107
  34. dataeval/flags/__init__.py +2 -2
  35. dataeval/metrics/__init__.py +26 -7
  36. dataeval/utils/__init__.py +9 -0
  37. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  38. dataeval-0.65.0.dist-info/RECORD +60 -0
  39. dataeval/_internal/functional/__init__.py +0 -0
  40. dataeval/_internal/functional/ber.py +0 -63
  41. dataeval/_internal/functional/coverage.py +0 -75
  42. dataeval/_internal/functional/divergence.py +0 -16
  43. dataeval/_internal/functional/hash.py +0 -79
  44. dataeval/_internal/functional/metadata.py +0 -136
  45. dataeval/_internal/functional/metadataparity.py +0 -190
  46. dataeval/_internal/functional/uap.py +0 -6
  47. dataeval/_internal/functional/utils.py +0 -158
  48. dataeval/_internal/maite/__init__.py +0 -0
  49. dataeval/_internal/maite/utils.py +0 -30
  50. dataeval/_internal/metrics/base.py +0 -92
  51. dataeval/_internal/metrics/metadata.py +0 -610
  52. dataeval/_internal/metrics/metadataparity.py +0 -67
  53. dataeval-0.63.0.dist-info/RECORD +0 -68
  54. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  55. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
@@ -1,343 +1,224 @@
1
- from enum import Flag
2
- from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Dict, Iterable, List
3
3
 
4
4
  import numpy as np
5
+ from numpy.typing import ArrayLike, NDArray
5
6
  from scipy.stats import entropy, kurtosis, skew
6
7
 
7
- from dataeval._internal.flags import ImageHash, ImageProperty, ImageStatistics, ImageStatsFlags, ImageVisuals
8
- from dataeval._internal.functional.hash import pchash, xxhash
9
- from dataeval._internal.functional.utils import edge_filter, get_bitdepth, normalize_image_shape, rescale
10
- from dataeval._internal.interop import ArrayLike, to_numpy_iter
11
- from dataeval._internal.metrics.base import EvaluateMixin, MetricMixin
8
+ from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
9
+ from dataeval._internal.interop import to_numpy_iter
10
+ from dataeval._internal.metrics.utils import edge_filter, get_bitdepth, normalize_image_shape, pchash, rescale, xxhash
11
+ from dataeval._internal.output import OutputMetadata, populate_defaults, set_metadata
12
12
 
13
- QUARTILES = (0, 25, 50, 75, 100)
14
-
15
- TBatch = TypeVar("TBatch", bound=Sequence[ArrayLike])
16
- TFlag = TypeVar("TFlag", bound=Flag)
17
-
18
-
19
- class BaseStatsMetric(EvaluateMixin, MetricMixin, Generic[TBatch, TFlag]):
20
- def __init__(self, flags: TFlag):
21
- self.flags = flags
22
- self.results = []
23
-
24
- def update(self, images: TBatch) -> None:
25
- """
26
- Updates internal metric cache for later calculation
27
-
28
- Parameters
29
- ----------
30
- batch : Sequence
31
- Sequence of images to be processed
32
- """
33
-
34
- def compute(self) -> Dict[str, Any]:
35
- """
36
- Computes the specified measures on the cached values
37
-
38
- Returns
39
- -------
40
- Dict[str, Any]
41
- Dictionary results of the specified measures
42
- """
43
- return {stat: [result[stat] for result in self.results] for stat in self.results[0]}
44
-
45
- def reset(self) -> None:
46
- """
47
- Resets the internal metric cache
48
- """
49
- self.results = []
50
-
51
- def _map(self, func_map: Dict[Flag, Callable]) -> Dict[str, Any]:
52
- """Calculates the measures for each flag if it is selected."""
53
- results = {}
54
- for flag, func in func_map.items():
55
- if not flag.name:
56
- raise ValueError("Provided flag to set value does not have a name.")
57
- if flag & self.flags:
58
- results[flag.name.lower()] = func()
59
- return results
60
-
61
- def _keys(self) -> List[str]:
62
- """Returns the list of measures to be calculated."""
63
- flags = (
64
- self.flags
65
- if isinstance(self.flags, Iterable) # py3.11
66
- else [flag for flag in list(self.flags.__class__) if flag & self.flags]
67
- )
68
- return [flag.name.lower() for flag in flags if flag.name is not None]
69
-
70
- def evaluate(self, images: TBatch) -> Dict[str, Any]:
71
- """Calculate metric results given a single batch of images"""
72
- if self.results:
73
- raise RuntimeError("Call reset before calling evaluate")
74
-
75
- self.update(images)
76
- results = self.compute()
77
- self.reset()
78
- return results
79
-
80
-
81
- class ImageHashMetric(BaseStatsMetric):
82
- """
83
- Hashes images using the specified algorithms
84
-
85
- Parameters
86
- ----------
87
- flags : ImageHash
88
- Algorithm(s) to calculate a hash as hex digest
89
- """
90
-
91
- def __init__(self, flags: ImageHash = ImageHash.ALL):
92
- super().__init__(flags)
93
-
94
- def update(self, images: Iterable[ArrayLike]) -> None:
95
- for image in to_numpy_iter(images):
96
- results = self._map(
97
- {
98
- ImageHash.XXHASH: lambda: xxhash(image),
99
- ImageHash.PCHASH: lambda: pchash(image),
100
- }
101
- )
102
- self.results.append(results)
13
+ CH_IDX_MAP = "ch_idx_map"
103
14
 
104
15
 
105
- class ImagePropertyMetric(BaseStatsMetric):
16
+ @dataclass(frozen=True)
17
+ class StatsOutput(OutputMetadata):
106
18
  """
107
- Calculates specified image properties
108
-
109
- Parameters
19
+ Attributes
110
20
  ----------
111
- flags: ImageProperty
112
- Property(ies) to calculate for each image
21
+ xxhash : List[str]
22
+ xxHash hash of the images as a hex string
23
+ pchash : List[str]
24
+ Perception hash of the images as a hex string
25
+ width: NDArray[np.uint16]
26
+ Width of the images in pixels
27
+ height: NDArray[np.uint16]
28
+ Height of the images in pixels
29
+ channels: NDArray[np.uint8]
30
+ Channel count of the images in pixels
31
+ size: NDArray[np.uint32]
32
+ Size of the images in pixels
33
+ aspect_ratio: NDArray[np.float16]
34
+ Aspect ratio of the images (width/height)
35
+ depth: NDArray[np.uint8]
36
+ Color depth of the images in bits
37
+ brightness: NDArray[np.float16]
38
+ Brightness of the images
39
+ blurriness: NDArray[np.float16]
40
+ Blurriness of the images
41
+ missing: NDArray[np.float16]
42
+ Percentage of the images with missing pixels
43
+ zero: NDArray[np.float16]
44
+ Percentage of the images with zero value pixels
45
+ mean: NDArray[np.float16]
46
+ Mean of the pixel values of the images
47
+ std: NDArray[np.float16]
48
+ Standard deviation of the pixel values of the images
49
+ var: NDArray[np.float16]
50
+ Variance of the pixel values of the images
51
+ skew: NDArray[np.float16]
52
+ Skew of the pixel values of the images
53
+ kurtosis: NDArray[np.float16]
54
+ Kurtosis of the pixel values of the images
55
+ percentiles: NDArray[np.float16]
56
+ Percentiles of the pixel values of the images with quartiles of (0, 25, 50, 75, 100)
57
+ histogram: NDArray[np.uint32]
58
+ Histogram of the pixel values of the images across 256 bins scaled between 0 and 1
59
+ entropy: NDArray[np.float16]
60
+ Entropy of the pixel values of the images
61
+ ch_idx_map: Dict[int, List[int]]
62
+ Per-channel mapping of indices for each metric
113
63
  """
114
64
 
115
- def __init__(self, flags: ImageProperty = ImageProperty.ALL):
116
- super().__init__(flags)
65
+ xxhash: List[str]
66
+ pchash: List[str]
67
+ width: NDArray[np.uint16]
68
+ height: NDArray[np.uint16]
69
+ channels: NDArray[np.uint8]
70
+ size: NDArray[np.uint32]
71
+ aspect_ratio: NDArray[np.float16]
72
+ depth: NDArray[np.uint8]
73
+ brightness: NDArray[np.float16]
74
+ blurriness: NDArray[np.float16]
75
+ missing: NDArray[np.float16]
76
+ zero: NDArray[np.float16]
77
+ mean: NDArray[np.float16]
78
+ std: NDArray[np.float16]
79
+ var: NDArray[np.float16]
80
+ skew: NDArray[np.float16]
81
+ kurtosis: NDArray[np.float16]
82
+ percentiles: NDArray[np.float16]
83
+ histogram: NDArray[np.uint32]
84
+ entropy: NDArray[np.float16]
85
+ ch_idx_map: Dict[int, List[int]]
86
+
87
+ def dict(self):
88
+ return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and len(v) > 0}
117
89
 
118
- def update(self, images: Iterable[ArrayLike]) -> None:
119
- for image in to_numpy_iter(images):
120
- results = self._map(
121
- {
122
- ImageProperty.WIDTH: lambda: np.int32(image.shape[-1]),
123
- ImageProperty.HEIGHT: lambda: np.int32(image.shape[-2]),
124
- ImageProperty.SIZE: lambda: np.int32(image.shape[-1] * image.shape[-2]),
125
- ImageProperty.ASPECT_RATIO: lambda: image.shape[-1] / np.int32(image.shape[-2]),
126
- ImageProperty.CHANNELS: lambda: image.shape[-3],
127
- ImageProperty.DEPTH: lambda: get_bitdepth(image).depth,
128
- }
129
- )
130
- self.results.append(results)
131
-
132
-
133
- class ImageVisualsMetric(BaseStatsMetric):
134
- """
135
- Calculates specified visual image properties
136
-
137
- Parameters
138
- ----------
139
- flags: ImageVisuals
140
- Property(ies) to calculate for each image
141
- """
142
-
143
- def __init__(self, flags: ImageVisuals = ImageVisuals.ALL):
144
- super().__init__(flags)
145
-
146
- def update(self, images: Iterable[ArrayLike]) -> None:
147
- for image in to_numpy_iter(images):
148
- results = self._map(
149
- {
150
- ImageVisuals.BRIGHTNESS: lambda: np.mean(rescale(image)),
151
- ImageVisuals.BLURRINESS: lambda: np.std(edge_filter(np.mean(image, axis=0))),
152
- ImageVisuals.MISSING: lambda: np.sum(np.isnan(image)),
153
- ImageVisuals.ZERO: lambda: np.int32(np.count_nonzero(image == 0)),
154
- }
155
- )
156
- self.results.append(results)
157
90
 
91
+ QUARTILES = (0, 25, 50, 75, 100)
158
92
 
159
- class ImageStatisticsMetric(BaseStatsMetric):
93
+ IMAGESTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
94
+ ImageStat.XXHASH: lambda x: xxhash(x),
95
+ ImageStat.PCHASH: lambda x: pchash(x),
96
+ ImageStat.WIDTH: lambda x: np.uint16(x.shape[-1]),
97
+ ImageStat.HEIGHT: lambda x: np.uint16(x.shape[-2]),
98
+ ImageStat.CHANNELS: lambda x: np.uint8(x.shape[-3]),
99
+ ImageStat.SIZE: lambda x: np.uint32(np.prod(x.shape[-2:])),
100
+ ImageStat.ASPECT_RATIO: lambda x: np.float16(x.shape[-1] / x.shape[-2]),
101
+ ImageStat.DEPTH: lambda x: np.uint8(get_bitdepth(x).depth),
102
+ ImageStat.BRIGHTNESS: lambda x: np.float16(np.mean(x)),
103
+ ImageStat.BLURRINESS: lambda x: np.float16(np.std(edge_filter(np.mean(x, axis=0)))),
104
+ ImageStat.MISSING: lambda x: np.float16(np.sum(np.isnan(x)) / np.prod(x.shape[-2:])),
105
+ ImageStat.ZERO: lambda x: np.float16(np.count_nonzero(x == 0) / np.prod(x.shape[-2:])),
106
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x)),
107
+ ImageStat.STD: lambda x: np.float16(np.std(x)),
108
+ ImageStat.VAR: lambda x: np.float16(np.var(x)),
109
+ ImageStat.SKEW: lambda x: np.float16(skew(x.ravel())),
110
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x.ravel())),
111
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES)),
112
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.histogram(x, 256, (0, 1))[0]),
113
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x)),
114
+ }
115
+
116
+ CHANNELSTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
117
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x, axis=1)),
118
+ ImageStat.STD: lambda x: np.float16(np.std(x, axis=1)),
119
+ ImageStat.VAR: lambda x: np.float16(np.var(x, axis=1)),
120
+ ImageStat.SKEW: lambda x: np.float16(skew(x, axis=1)),
121
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x, axis=1)),
122
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES, axis=1).T),
123
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.apply_along_axis(lambda y: np.histogram(y, 256, (0, 1))[0], 1, x)),
124
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x, axis=1)),
125
+ }
126
+
127
+
128
+ def run_stats(
129
+ images: Iterable[ArrayLike],
130
+ flags: ImageStat,
131
+ fn_map: Dict[ImageStat, Callable[[NDArray], Any]],
132
+ flatten: bool,
133
+ ):
134
+ verify_supported(flags, fn_map)
135
+ flag_dict = to_distinct(flags)
136
+
137
+ results_list: List[Dict[str, NDArray]] = []
138
+ for image in to_numpy_iter(images):
139
+ normalized = normalize_image_shape(image)
140
+ scaled = None
141
+ hist = None
142
+ output: Dict[str, NDArray] = {}
143
+ for flag, stat in flag_dict.items():
144
+ if flag & (ImageStat.ALL_PIXELSTATS | ImageStat.BRIGHTNESS):
145
+ if scaled is None:
146
+ scaled = rescale(normalized).reshape(image.shape[0], -1) if flatten else rescale(normalized)
147
+ if flag & (ImageStat.HISTOGRAM | ImageStat.ENTROPY):
148
+ if hist is None:
149
+ hist = fn_map[ImageStat.HISTOGRAM](scaled)
150
+ output[stat] = hist if flag & ImageStat.HISTOGRAM else fn_map[flag](hist)
151
+ else:
152
+ output[stat] = fn_map[flag](scaled)
153
+ else:
154
+ output[stat] = fn_map[flag](normalized)
155
+ results_list.append(output)
156
+ return results_list
157
+
158
+
159
+ @set_metadata("dataeval.metrics")
160
+ def imagestats(images: Iterable[ArrayLike], flags: ImageStat = ImageStat.ALL_STATS) -> StatsOutput:
160
161
  """
161
- Calculates descriptive statistics for each image
162
+ Calculates image and pixel statistics for each image
162
163
 
163
164
  Parameters
164
165
  ----------
165
- flags: ImageStatistics
166
- Statistic(s) to calculate for each image
166
+ images : Iterable[ArrayLike]
167
+ Images to run statistical tests on
168
+ flags : ImageStat, default ImageStat.ALL_STATS
169
+ Metric(s) to calculate for each image
170
+
171
+ Returns
172
+ -------
173
+ Dict[str, Any]
167
174
  """
168
-
169
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
170
- super().__init__(flags)
171
-
172
- def update(self, images: Iterable[ArrayLike]) -> None:
173
- for image in to_numpy_iter(images):
174
- scaled = rescale(image)
175
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
176
- hist = np.histogram(scaled, bins=256, range=(0, 1))[0]
177
-
178
- results = self._map(
179
- {
180
- ImageStatistics.MEAN: lambda: np.mean(scaled),
181
- ImageStatistics.STD: lambda: np.std(scaled),
182
- ImageStatistics.VAR: lambda: np.var(scaled),
183
- ImageStatistics.SKEW: lambda: np.float32(skew(scaled.ravel())),
184
- ImageStatistics.KURTOSIS: lambda: np.float32(kurtosis(scaled.ravel())),
185
- ImageStatistics.PERCENTILES: lambda: np.percentile(scaled, q=QUARTILES),
186
- ImageStatistics.HISTOGRAM: lambda: hist,
187
- ImageStatistics.ENTROPY: lambda: np.float32(entropy(hist)),
188
- }
189
- )
190
- self.results.append(results)
191
-
192
-
193
- class ChannelStatisticsMetric(BaseStatsMetric):
175
+ stats = run_stats(images, flags, IMAGESTATS_FN_MAP, False)
176
+ output = {}
177
+ length = len(stats)
178
+ for i, results in enumerate(stats):
179
+ for stat, result in results.items():
180
+ if not isinstance(result, (np.ndarray, np.generic)):
181
+ output.setdefault(stat, []).append(result)
182
+ else:
183
+ shape = () if np.isscalar(result) else result.shape
184
+ output.setdefault(stat, np.empty((length,) + shape))[i] = result
185
+ return StatsOutput(**populate_defaults(output, StatsOutput))
186
+
187
+
188
+ @set_metadata("dataeval.metrics")
189
+ def channelstats(images: Iterable[ArrayLike], flags=ImageStat.ALL_PIXELSTATS) -> StatsOutput:
194
190
  """
195
- Calculates descriptive statistics for each image per channel
191
+ Calculates pixel statistics for each image per channel
196
192
 
197
193
  Parameters
198
194
  ----------
199
- flags: ImageStatistics
195
+ images : Iterable[ArrayLike]
196
+ Images to run statistical tests on
197
+ flags: ImageStat, default ImageStat.ALL_PIXELSTATS
200
198
  Statistic(s) to calculate for each image per channel
201
- """
202
-
203
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
204
- super().__init__(flags)
205
-
206
- def update(self, images: Iterable[ArrayLike]) -> None:
207
- for image in to_numpy_iter(images):
208
- scaled = rescale(image)
209
- flattened = scaled.reshape(image.shape[0], -1)
210
-
211
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
212
- hist = np.apply_along_axis(lambda x: np.histogram(x, bins=256, range=(0, 1))[0], 1, flattened)
199
+ Only flags in the ImageStat.ALL_PIXELSTATS category are supported
213
200
 
214
- results = self._map(
215
- {
216
- ImageStatistics.MEAN: lambda: np.mean(flattened, axis=1),
217
- ImageStatistics.STD: lambda: np.std(flattened, axis=1),
218
- ImageStatistics.VAR: lambda: np.var(flattened, axis=1),
219
- ImageStatistics.SKEW: lambda: skew(flattened, axis=1),
220
- ImageStatistics.KURTOSIS: lambda: kurtosis(flattened, axis=1),
221
- ImageStatistics.PERCENTILES: lambda: np.percentile(flattened, q=QUARTILES, axis=1).T,
222
- ImageStatistics.HISTOGRAM: lambda: hist,
223
- ImageStatistics.ENTROPY: lambda: entropy(hist, axis=1),
224
- }
225
- )
226
- self.results.append(results)
227
-
228
-
229
- class BaseAggregateMetric(BaseStatsMetric, Generic[TFlag]):
230
- FLAG_METRIC_MAP: Dict[type, type]
231
- DEFAULT_FLAGS: Sequence[TFlag]
232
-
233
- def __init__(self, flags: Optional[Union[TFlag, Sequence[TFlag]]] = None):
234
- flag_dict = {}
235
- for flag in flags if isinstance(flags, Sequence) else self.DEFAULT_FLAGS if not flags else [flags]:
236
- flag_dict[type(flag)] = flag_dict.setdefault(type(flag), type(flag)(0)) | flag
237
- self._metrics_dict = {
238
- metric: []
239
- for metric in (
240
- self.FLAG_METRIC_MAP[flag_class](flag) for flag_class, flag in flag_dict.items() if flag.value != 0
241
- )
242
- }
243
-
244
-
245
- class ImageStats(BaseAggregateMetric):
246
- """
247
- Calculates various image property statistics
248
-
249
- Parameters
250
- ----------
251
- flags: [ImageHash | ImageProperty | ImageStatistics | ImageVisuals], default None
252
- Metric(s) to calculate for each image per channel - calculates all metrics if None
201
+ Returns
202
+ -------
203
+ Dict[str, Any]
253
204
  """
254
-
255
- FLAG_METRIC_MAP = {
256
- ImageHash: ImageHashMetric,
257
- ImageProperty: ImagePropertyMetric,
258
- ImageStatistics: ImageStatisticsMetric,
259
- ImageVisuals: ImageVisualsMetric,
260
- }
261
- DEFAULT_FLAGS = [ImageHash.ALL, ImageProperty.ALL, ImageStatistics.ALL, ImageVisuals.ALL]
262
-
263
- def __init__(self, flags: Optional[Union[ImageStatsFlags, Sequence[ImageStatsFlags]]] = None):
264
- super().__init__(flags)
265
- self._length = 0
266
-
267
- def update(self, images: Iterable[ArrayLike]) -> None:
268
- for image in to_numpy_iter(images):
269
- self._length += 1
270
- img = normalize_image_shape(image)
271
- for metric in self._metrics_dict:
272
- metric.update([img])
273
-
274
- def compute(self) -> Dict[str, Any]:
275
- for metric in self._metrics_dict:
276
- self._metrics_dict[metric] = metric.results
277
-
278
- stats = {}
279
- for metric, results in self._metrics_dict.items():
280
- for i, result in enumerate(results):
281
- for stat in metric._keys():
282
- value = result[stat]
283
- if not isinstance(value, (np.ndarray, np.generic)):
284
- if stat not in stats:
285
- stats[stat] = []
286
- stats[stat].append(result[stat])
287
- else:
288
- if stat not in stats:
289
- shape = () if np.isscalar(result[stat]) else result[stat].shape
290
- stats[stat] = np.empty((self._length,) + shape)
291
- stats[stat][i] = result[stat]
292
- return stats
293
-
294
- def reset(self):
295
- self._length = 0
296
- for metric in self._metrics_dict:
297
- metric.reset()
298
- self._metrics_dict[metric] = []
299
-
300
-
301
- class ChannelStats(BaseAggregateMetric):
302
- FLAG_METRIC_MAP = {ImageStatistics: ChannelStatisticsMetric}
303
- DEFAULT_FLAGS = [ImageStatistics.ALL]
304
- IDX_MAP = "idx_map"
305
-
306
- def __init__(self, flags: Optional[ImageStatistics] = None) -> None:
307
- super().__init__(flags)
308
-
309
- def update(self, images: Iterable[ArrayLike]) -> None:
310
- for image in to_numpy_iter(images):
311
- img = normalize_image_shape(image)
312
- for metric in self._metrics_dict:
313
- metric.update([img])
314
-
315
- for metric in self._metrics_dict:
316
- self._metrics_dict[metric] = metric.results
317
-
318
- def compute(self) -> Dict[str, Any]:
319
- # Aggregate all metrics into a single dictionary
320
- stats = {}
321
- channel_stats = set()
322
- for metric, results in self._metrics_dict.items():
323
- for i, result in enumerate(results):
324
- for stat in metric._keys():
325
- channel_stats.update(metric._keys())
326
- channels = result[stat].shape[0]
327
- stats.setdefault(self.IDX_MAP, {}).setdefault(channels, {})[i] = None
328
- stats.setdefault(stat, {}).setdefault(channels, []).append(result[stat])
329
-
330
- # Concatenate list of channel statistics numpy
331
- for stat in channel_stats:
332
- for channel in stats[stat]:
333
- stats[stat][channel] = np.array(stats[stat][channel]).T
334
-
335
- for channel in stats[self.IDX_MAP]:
336
- stats[self.IDX_MAP][channel] = list(stats[self.IDX_MAP][channel].keys())
337
-
338
- return stats
339
-
340
- def reset(self) -> None:
341
- for metric in self._metrics_dict:
342
- metric.reset()
343
- self._metrics_dict[metric] = []
205
+ stats = run_stats(images, flags, CHANNELSTATS_FN_MAP, True)
206
+
207
+ output = {}
208
+ for i, results in enumerate(stats):
209
+ for stat, result in results.items():
210
+ channels = result.shape[0]
211
+ output.setdefault(stat, {}).setdefault(channels, []).append(result)
212
+ output.setdefault(CH_IDX_MAP, {}).setdefault(channels, {})[i] = None
213
+
214
+ # Concatenate list of channel statistics numpy
215
+ for stat in output:
216
+ if stat == CH_IDX_MAP:
217
+ continue
218
+ for channel in output[stat]:
219
+ output[stat][channel] = np.array(output[stat][channel]).T
220
+
221
+ for channel in output[CH_IDX_MAP]:
222
+ output[CH_IDX_MAP][channel] = list(output[CH_IDX_MAP][channel].keys())
223
+
224
+ return StatsOutput(**populate_defaults(output, StatsOutput))
@@ -4,39 +4,50 @@ FR Test Statistic based estimate for the upperbound
4
4
  average precision using empirical mean precision
5
5
  """
6
6
 
7
- from typing import Dict
7
+ from dataclasses import dataclass
8
8
 
9
- from dataeval._internal.functional.uap import uap
10
- from dataeval._internal.interop import ArrayLike, to_numpy
11
- from dataeval._internal.metrics.base import EvaluateMixin
9
+ from numpy.typing import ArrayLike
10
+ from sklearn.metrics import average_precision_score
12
11
 
12
+ from dataeval._internal.interop import to_numpy
13
+ from dataeval._internal.output import OutputMetadata, set_metadata
13
14
 
14
- class UAP(EvaluateMixin):
15
+
16
+ @dataclass(frozen=True)
17
+ class UAPOutput(OutputMetadata):
18
+ """
19
+ Attributes
20
+ ----------
21
+ uap : float
22
+ The empirical mean precision estimate
15
23
  """
16
- FR Test Statistic based estimate of the empirical mean precision
17
24
 
25
+ uap: float
26
+
27
+
28
+ @set_metadata("dataeval.metrics")
29
+ def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
30
+ """
31
+ FR Test Statistic based estimate of the empirical mean precision for
32
+ the upperbound average precision
33
+
34
+ Parameters
35
+ ----------
36
+ labels : ArrayLike
37
+ A numpy array of n_samples of class labels with M unique classes.
38
+ scores : ArrayLike
39
+ A 2D array of class probabilities per image
40
+
41
+ Returns
42
+ -------
43
+ Dict[str, float]
44
+ uap : The empirical mean precision estimate
45
+
46
+ Raises
47
+ ------
48
+ ValueError
49
+ If unique classes M < 2
18
50
  """
19
51
 
20
- def evaluate(self, labels: ArrayLike, scores: ArrayLike) -> Dict[str, float]:
21
- """
22
- Estimates the upperbound average precision
23
-
24
- Parameters
25
- ----------
26
- labels : ArrayLike
27
- A numpy array of n_samples of class labels with M unique classes.
28
- scores : ArrayLike
29
- A 2D array of class probabilities per image
30
-
31
- Returns
32
- -------
33
- Dict[str, float]
34
- uap : The empirical mean precision estimate
35
-
36
- Raises
37
- ------
38
- ValueError
39
- If unique classes M < 2
40
- """
41
-
42
- return {"uap": uap(to_numpy(labels), to_numpy(scores))}
52
+ precision = float(average_precision_score(to_numpy(labels), to_numpy(scores), average="weighted"))
53
+ return UAPOutput(precision)