dataeval-0.64.0-py3-none-any.whl → dataeval-0.65.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. dataeval/__init__.py +2 -2
  2. dataeval/_internal/detectors/clusterer.py +46 -34
  3. dataeval/_internal/detectors/drift/base.py +52 -35
  4. dataeval/_internal/detectors/drift/cvm.py +4 -4
  5. dataeval/_internal/detectors/drift/ks.py +6 -6
  6. dataeval/_internal/detectors/drift/mmd.py +35 -16
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -7
  9. dataeval/_internal/detectors/duplicates.py +55 -29
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/base.py +36 -15
  12. dataeval/_internal/detectors/ood/llr.py +7 -7
  13. dataeval/_internal/flags.py +42 -21
  14. dataeval/_internal/interop.py +2 -2
  15. dataeval/_internal/metrics/balance.py +10 -2
  16. dataeval/_internal/metrics/ber.py +6 -5
  17. dataeval/_internal/metrics/coverage.py +15 -8
  18. dataeval/_internal/metrics/divergence.py +41 -7
  19. dataeval/_internal/metrics/diversity.py +17 -12
  20. dataeval/_internal/metrics/parity.py +30 -43
  21. dataeval/_internal/metrics/stats.py +196 -317
  22. dataeval/_internal/metrics/uap.py +5 -2
  23. dataeval/_internal/metrics/utils.py +70 -33
  24. dataeval/_internal/models/tensorflow/losses.py +3 -3
  25. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  26. dataeval/_internal/models/tensorflow/utils.py +4 -3
  27. dataeval/_internal/output.py +82 -0
  28. dataeval/_internal/workflows/sufficiency.py +96 -107
  29. dataeval/flags/__init__.py +2 -2
  30. dataeval/metrics/__init__.py +3 -3
  31. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  32. dataeval-0.65.0.dist-info/RECORD +60 -0
  33. dataeval/_internal/metrics/base.py +0 -10
  34. dataeval-0.64.0.dist-info/RECORD +0 -60
  35. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  36. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
@@ -1,345 +1,224 @@
1
- from abc import abstractmethod
2
- from enum import Flag
3
- from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Dict, Iterable, List
4
3
 
5
4
  import numpy as np
6
- from numpy.typing import ArrayLike
5
+ from numpy.typing import ArrayLike, NDArray
7
6
  from scipy.stats import entropy, kurtosis, skew
8
7
 
9
- from dataeval._internal.flags import ImageHash, ImageProperty, ImageStatistics, ImageStatsFlags, ImageVisuals
8
+ from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
10
9
  from dataeval._internal.interop import to_numpy_iter
11
- from dataeval._internal.metrics.base import EvaluateMixin
12
10
  from dataeval._internal.metrics.utils import edge_filter, get_bitdepth, normalize_image_shape, pchash, rescale, xxhash
11
+ from dataeval._internal.output import OutputMetadata, populate_defaults, set_metadata
13
12
 
14
- QUARTILES = (0, 25, 50, 75, 100)
15
-
16
- TBatch = TypeVar("TBatch", bound=Sequence[ArrayLike])
17
- TFlag = TypeVar("TFlag", bound=Flag)
18
-
19
-
20
- class BaseStatsMetric(EvaluateMixin, Generic[TBatch, TFlag]):
21
- def __init__(self, flags: TFlag):
22
- self.flags = flags
23
- self.results = []
24
-
25
- @abstractmethod
26
- def update(self, images: TBatch) -> None:
27
- """
28
- Updates internal metric cache for later calculation
29
-
30
- Parameters
31
- ----------
32
- batch : Sequence
33
- Sequence of images to be processed
34
- """
35
-
36
- def compute(self) -> Dict[str, Any]:
37
- """
38
- Computes the specified measures on the cached values
39
-
40
- Returns
41
- -------
42
- Dict[str, Any]
43
- Dictionary results of the specified measures
44
- """
45
- return {stat: [result[stat] for result in self.results] for stat in self.results[0]}
46
-
47
- def reset(self) -> None:
48
- """
49
- Resets the internal metric cache
50
- """
51
- self.results = []
52
-
53
- def _map(self, func_map: Dict[Flag, Callable]) -> Dict[str, Any]:
54
- """Calculates the measures for each flag if it is selected."""
55
- results = {}
56
- for flag, func in func_map.items():
57
- if not flag.name:
58
- raise ValueError("Provided flag to set value does not have a name.")
59
- if flag & self.flags:
60
- results[flag.name.lower()] = func()
61
- return results
62
-
63
- def _keys(self) -> List[str]:
64
- """Returns the list of measures to be calculated."""
65
- flags = (
66
- self.flags
67
- if isinstance(self.flags, Iterable) # py3.11
68
- else [flag for flag in list(self.flags.__class__) if flag & self.flags]
69
- )
70
- return [flag.name.lower() for flag in flags if flag.name is not None]
71
-
72
- def evaluate(self, images: TBatch) -> Dict[str, Any]:
73
- """Calculate metric results given a single batch of images"""
74
- if self.results:
75
- raise RuntimeError("Call reset before calling evaluate")
76
-
77
- self.update(images)
78
- results = self.compute()
79
- self.reset()
80
- return results
81
-
82
-
83
- class ImageHashMetric(BaseStatsMetric):
84
- """
85
- Hashes images using the specified algorithms
86
-
87
- Parameters
88
- ----------
89
- flags : ImageHash
90
- Algorithm(s) to calculate a hash as hex digest
91
- """
92
-
93
- def __init__(self, flags: ImageHash = ImageHash.ALL):
94
- super().__init__(flags)
95
-
96
- def update(self, images: Iterable[ArrayLike]) -> None:
97
- for image in to_numpy_iter(images):
98
- results = self._map(
99
- {
100
- ImageHash.XXHASH: lambda: xxhash(image),
101
- ImageHash.PCHASH: lambda: pchash(image),
102
- }
103
- )
104
- self.results.append(results)
13
+ CH_IDX_MAP = "ch_idx_map"
105
14
 
106
15
 
107
- class ImagePropertyMetric(BaseStatsMetric):
16
+ @dataclass(frozen=True)
17
+ class StatsOutput(OutputMetadata):
108
18
  """
109
- Calculates specified image properties
110
-
111
- Parameters
19
+ Attributes
112
20
  ----------
113
- flags: ImageProperty
114
- Property(ies) to calculate for each image
21
+ xxhash : List[str]
22
+ xxHash hash of the images as a hex string
23
+ pchash : List[str]
24
+ Perception hash of the images as a hex string
25
+ width: NDArray[np.uint16]
26
+ Width of the images in pixels
27
+ height: NDArray[np.uint16]
28
+ Height of the images in pixels
29
+ channels: NDArray[np.uint8]
30
+ Channel count of the images in pixels
31
+ size: NDArray[np.uint32]
32
+ Size of the images in pixels
33
+ aspect_ratio: NDArray[np.float16]
34
+ Aspect ratio of the images (width/height)
35
+ depth: NDArray[np.uint8]
36
+ Color depth of the images in bits
37
+ brightness: NDArray[np.float16]
38
+ Brightness of the images
39
+ blurriness: NDArray[np.float16]
40
+ Blurriness of the images
41
+ missing: NDArray[np.float16]
42
+ Percentage of the images with missing pixels
43
+ zero: NDArray[np.float16]
44
+ Percentage of the images with zero value pixels
45
+ mean: NDArray[np.float16]
46
+ Mean of the pixel values of the images
47
+ std: NDArray[np.float16]
48
+ Standard deviation of the pixel values of the images
49
+ var: NDArray[np.float16]
50
+ Variance of the pixel values of the images
51
+ skew: NDArray[np.float16]
52
+ Skew of the pixel values of the images
53
+ kurtosis: NDArray[np.float16]
54
+ Kurtosis of the pixel values of the images
55
+ percentiles: NDArray[np.float16]
56
+ Percentiles of the pixel values of the images with quartiles of (0, 25, 50, 75, 100)
57
+ histogram: NDArray[np.uint32]
58
+ Histogram of the pixel values of the images across 256 bins scaled between 0 and 1
59
+ entropy: NDArray[np.float16]
60
+ Entropy of the pixel values of the images
61
+ ch_idx_map: Dict[int, List[int]]
62
+ Per-channel mapping of indices for each metric
115
63
  """
116
64
 
117
- def __init__(self, flags: ImageProperty = ImageProperty.ALL):
118
- super().__init__(flags)
65
+ xxhash: List[str]
66
+ pchash: List[str]
67
+ width: NDArray[np.uint16]
68
+ height: NDArray[np.uint16]
69
+ channels: NDArray[np.uint8]
70
+ size: NDArray[np.uint32]
71
+ aspect_ratio: NDArray[np.float16]
72
+ depth: NDArray[np.uint8]
73
+ brightness: NDArray[np.float16]
74
+ blurriness: NDArray[np.float16]
75
+ missing: NDArray[np.float16]
76
+ zero: NDArray[np.float16]
77
+ mean: NDArray[np.float16]
78
+ std: NDArray[np.float16]
79
+ var: NDArray[np.float16]
80
+ skew: NDArray[np.float16]
81
+ kurtosis: NDArray[np.float16]
82
+ percentiles: NDArray[np.float16]
83
+ histogram: NDArray[np.uint32]
84
+ entropy: NDArray[np.float16]
85
+ ch_idx_map: Dict[int, List[int]]
86
+
87
+ def dict(self):
88
+ return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and len(v) > 0}
119
89
 
120
- def update(self, images: Iterable[ArrayLike]) -> None:
121
- for image in to_numpy_iter(images):
122
- results = self._map(
123
- {
124
- ImageProperty.WIDTH: lambda: np.int32(image.shape[-1]),
125
- ImageProperty.HEIGHT: lambda: np.int32(image.shape[-2]),
126
- ImageProperty.SIZE: lambda: np.int32(image.shape[-1] * image.shape[-2]),
127
- ImageProperty.ASPECT_RATIO: lambda: image.shape[-1] / np.int32(image.shape[-2]),
128
- ImageProperty.CHANNELS: lambda: image.shape[-3],
129
- ImageProperty.DEPTH: lambda: get_bitdepth(image).depth,
130
- }
131
- )
132
- self.results.append(results)
133
-
134
-
135
- class ImageVisualsMetric(BaseStatsMetric):
136
- """
137
- Calculates specified visual image properties
138
-
139
- Parameters
140
- ----------
141
- flags: ImageVisuals
142
- Property(ies) to calculate for each image
143
- """
144
-
145
- def __init__(self, flags: ImageVisuals = ImageVisuals.ALL):
146
- super().__init__(flags)
147
-
148
- def update(self, images: Iterable[ArrayLike]) -> None:
149
- for image in to_numpy_iter(images):
150
- results = self._map(
151
- {
152
- ImageVisuals.BRIGHTNESS: lambda: np.mean(rescale(image)),
153
- ImageVisuals.BLURRINESS: lambda: np.std(edge_filter(np.mean(image, axis=0))),
154
- ImageVisuals.MISSING: lambda: np.sum(np.isnan(image)),
155
- ImageVisuals.ZERO: lambda: np.int32(np.count_nonzero(image == 0)),
156
- }
157
- )
158
- self.results.append(results)
159
90
 
91
+ QUARTILES = (0, 25, 50, 75, 100)
160
92
 
161
- class ImageStatisticsMetric(BaseStatsMetric):
93
+ IMAGESTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
94
+ ImageStat.XXHASH: lambda x: xxhash(x),
95
+ ImageStat.PCHASH: lambda x: pchash(x),
96
+ ImageStat.WIDTH: lambda x: np.uint16(x.shape[-1]),
97
+ ImageStat.HEIGHT: lambda x: np.uint16(x.shape[-2]),
98
+ ImageStat.CHANNELS: lambda x: np.uint8(x.shape[-3]),
99
+ ImageStat.SIZE: lambda x: np.uint32(np.prod(x.shape[-2:])),
100
+ ImageStat.ASPECT_RATIO: lambda x: np.float16(x.shape[-1] / x.shape[-2]),
101
+ ImageStat.DEPTH: lambda x: np.uint8(get_bitdepth(x).depth),
102
+ ImageStat.BRIGHTNESS: lambda x: np.float16(np.mean(x)),
103
+ ImageStat.BLURRINESS: lambda x: np.float16(np.std(edge_filter(np.mean(x, axis=0)))),
104
+ ImageStat.MISSING: lambda x: np.float16(np.sum(np.isnan(x)) / np.prod(x.shape[-2:])),
105
+ ImageStat.ZERO: lambda x: np.float16(np.count_nonzero(x == 0) / np.prod(x.shape[-2:])),
106
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x)),
107
+ ImageStat.STD: lambda x: np.float16(np.std(x)),
108
+ ImageStat.VAR: lambda x: np.float16(np.var(x)),
109
+ ImageStat.SKEW: lambda x: np.float16(skew(x.ravel())),
110
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x.ravel())),
111
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES)),
112
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.histogram(x, 256, (0, 1))[0]),
113
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x)),
114
+ }
115
+
116
+ CHANNELSTATS_FN_MAP: Dict[ImageStat, Callable[[NDArray], Any]] = {
117
+ ImageStat.MEAN: lambda x: np.float16(np.mean(x, axis=1)),
118
+ ImageStat.STD: lambda x: np.float16(np.std(x, axis=1)),
119
+ ImageStat.VAR: lambda x: np.float16(np.var(x, axis=1)),
120
+ ImageStat.SKEW: lambda x: np.float16(skew(x, axis=1)),
121
+ ImageStat.KURTOSIS: lambda x: np.float16(kurtosis(x, axis=1)),
122
+ ImageStat.PERCENTILES: lambda x: np.float16(np.percentile(x, q=QUARTILES, axis=1).T),
123
+ ImageStat.HISTOGRAM: lambda x: np.uint32(np.apply_along_axis(lambda y: np.histogram(y, 256, (0, 1))[0], 1, x)),
124
+ ImageStat.ENTROPY: lambda x: np.float16(entropy(x, axis=1)),
125
+ }
126
+
127
+
128
+ def run_stats(
129
+ images: Iterable[ArrayLike],
130
+ flags: ImageStat,
131
+ fn_map: Dict[ImageStat, Callable[[NDArray], Any]],
132
+ flatten: bool,
133
+ ):
134
+ verify_supported(flags, fn_map)
135
+ flag_dict = to_distinct(flags)
136
+
137
+ results_list: List[Dict[str, NDArray]] = []
138
+ for image in to_numpy_iter(images):
139
+ normalized = normalize_image_shape(image)
140
+ scaled = None
141
+ hist = None
142
+ output: Dict[str, NDArray] = {}
143
+ for flag, stat in flag_dict.items():
144
+ if flag & (ImageStat.ALL_PIXELSTATS | ImageStat.BRIGHTNESS):
145
+ if scaled is None:
146
+ scaled = rescale(normalized).reshape(image.shape[0], -1) if flatten else rescale(normalized)
147
+ if flag & (ImageStat.HISTOGRAM | ImageStat.ENTROPY):
148
+ if hist is None:
149
+ hist = fn_map[ImageStat.HISTOGRAM](scaled)
150
+ output[stat] = hist if flag & ImageStat.HISTOGRAM else fn_map[flag](hist)
151
+ else:
152
+ output[stat] = fn_map[flag](scaled)
153
+ else:
154
+ output[stat] = fn_map[flag](normalized)
155
+ results_list.append(output)
156
+ return results_list
157
+
158
+
159
+ @set_metadata("dataeval.metrics")
160
+ def imagestats(images: Iterable[ArrayLike], flags: ImageStat = ImageStat.ALL_STATS) -> StatsOutput:
162
161
  """
163
- Calculates descriptive statistics for each image
162
+ Calculates image and pixel statistics for each image
164
163
 
165
164
  Parameters
166
165
  ----------
167
- flags: ImageStatistics
168
- Statistic(s) to calculate for each image
166
+ images : Iterable[ArrayLike]
167
+ Images to run statistical tests on
168
+ flags : ImageStat, default ImageStat.ALL_STATS
169
+ Metric(s) to calculate for each image
170
+
171
+ Returns
172
+ -------
173
+ Dict[str, Any]
169
174
  """
170
-
171
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
172
- super().__init__(flags)
173
-
174
- def update(self, images: Iterable[ArrayLike]) -> None:
175
- for image in to_numpy_iter(images):
176
- scaled = rescale(image)
177
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
178
- hist = np.histogram(scaled, bins=256, range=(0, 1))[0]
179
-
180
- results = self._map(
181
- {
182
- ImageStatistics.MEAN: lambda: np.mean(scaled),
183
- ImageStatistics.STD: lambda: np.std(scaled),
184
- ImageStatistics.VAR: lambda: np.var(scaled),
185
- ImageStatistics.SKEW: lambda: np.float32(skew(scaled.ravel())),
186
- ImageStatistics.KURTOSIS: lambda: np.float32(kurtosis(scaled.ravel())),
187
- ImageStatistics.PERCENTILES: lambda: np.percentile(scaled, q=QUARTILES),
188
- ImageStatistics.HISTOGRAM: lambda: hist,
189
- ImageStatistics.ENTROPY: lambda: np.float32(entropy(hist)),
190
- }
191
- )
192
- self.results.append(results)
193
-
194
-
195
- class ChannelStatisticsMetric(BaseStatsMetric):
175
+ stats = run_stats(images, flags, IMAGESTATS_FN_MAP, False)
176
+ output = {}
177
+ length = len(stats)
178
+ for i, results in enumerate(stats):
179
+ for stat, result in results.items():
180
+ if not isinstance(result, (np.ndarray, np.generic)):
181
+ output.setdefault(stat, []).append(result)
182
+ else:
183
+ shape = () if np.isscalar(result) else result.shape
184
+ output.setdefault(stat, np.empty((length,) + shape))[i] = result
185
+ return StatsOutput(**populate_defaults(output, StatsOutput))
186
+
187
+
188
+ @set_metadata("dataeval.metrics")
189
+ def channelstats(images: Iterable[ArrayLike], flags=ImageStat.ALL_PIXELSTATS) -> StatsOutput:
196
190
  """
197
- Calculates descriptive statistics for each image per channel
191
+ Calculates pixel statistics for each image per channel
198
192
 
199
193
  Parameters
200
194
  ----------
201
- flags: ImageStatistics
195
+ images : Iterable[ArrayLike]
196
+ Images to run statistical tests on
197
+ flags: ImageStat, default ImageStat.ALL_PIXELSTATS
202
198
  Statistic(s) to calculate for each image per channel
203
- """
204
-
205
- def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
206
- super().__init__(flags)
207
-
208
- def update(self, images: Iterable[ArrayLike]) -> None:
209
- for image in to_numpy_iter(images):
210
- scaled = rescale(image)
211
- flattened = scaled.reshape(image.shape[0], -1)
212
-
213
- if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
214
- hist = np.apply_along_axis(lambda x: np.histogram(x, bins=256, range=(0, 1))[0], 1, flattened)
199
+ Only flags in the ImageStat.ALL_PIXELSTATS category are supported
215
200
 
216
- results = self._map(
217
- {
218
- ImageStatistics.MEAN: lambda: np.mean(flattened, axis=1),
219
- ImageStatistics.STD: lambda: np.std(flattened, axis=1),
220
- ImageStatistics.VAR: lambda: np.var(flattened, axis=1),
221
- ImageStatistics.SKEW: lambda: skew(flattened, axis=1),
222
- ImageStatistics.KURTOSIS: lambda: kurtosis(flattened, axis=1),
223
- ImageStatistics.PERCENTILES: lambda: np.percentile(flattened, q=QUARTILES, axis=1).T,
224
- ImageStatistics.HISTOGRAM: lambda: hist,
225
- ImageStatistics.ENTROPY: lambda: entropy(hist, axis=1),
226
- }
227
- )
228
- self.results.append(results)
229
-
230
-
231
- class BaseAggregateMetric(BaseStatsMetric, Generic[TFlag]):
232
- FLAG_METRIC_MAP: Dict[type, type]
233
- DEFAULT_FLAGS: Sequence[TFlag]
234
-
235
- def __init__(self, flags: Optional[Union[TFlag, Sequence[TFlag]]] = None):
236
- flag_dict = {}
237
- for flag in flags if isinstance(flags, Sequence) else self.DEFAULT_FLAGS if not flags else [flags]:
238
- flag_dict[type(flag)] = flag_dict.setdefault(type(flag), type(flag)(0)) | flag
239
- self._metrics_dict = {
240
- metric: []
241
- for metric in (
242
- self.FLAG_METRIC_MAP[flag_class](flag) for flag_class, flag in flag_dict.items() if flag.value != 0
243
- )
244
- }
245
-
246
-
247
- class ImageStats(BaseAggregateMetric):
248
- """
249
- Calculates various image property statistics
250
-
251
- Parameters
252
- ----------
253
- flags: [ImageHash | ImageProperty | ImageStatistics | ImageVisuals], default None
254
- Metric(s) to calculate for each image per channel - calculates all metrics if None
201
+ Returns
202
+ -------
203
+ Dict[str, Any]
255
204
  """
256
-
257
- FLAG_METRIC_MAP = {
258
- ImageHash: ImageHashMetric,
259
- ImageProperty: ImagePropertyMetric,
260
- ImageStatistics: ImageStatisticsMetric,
261
- ImageVisuals: ImageVisualsMetric,
262
- }
263
- DEFAULT_FLAGS = [ImageHash.ALL, ImageProperty.ALL, ImageStatistics.ALL, ImageVisuals.ALL]
264
-
265
- def __init__(self, flags: Optional[Union[ImageStatsFlags, Sequence[ImageStatsFlags]]] = None):
266
- super().__init__(flags)
267
- self._length = 0
268
-
269
- def update(self, images: Iterable[ArrayLike]) -> None:
270
- for image in to_numpy_iter(images):
271
- self._length += 1
272
- img = normalize_image_shape(image)
273
- for metric in self._metrics_dict:
274
- metric.update([img])
275
-
276
- def compute(self) -> Dict[str, Any]:
277
- for metric in self._metrics_dict:
278
- self._metrics_dict[metric] = metric.results
279
-
280
- stats = {}
281
- for metric, results in self._metrics_dict.items():
282
- for i, result in enumerate(results):
283
- for stat in metric._keys():
284
- value = result[stat]
285
- if not isinstance(value, (np.ndarray, np.generic)):
286
- if stat not in stats:
287
- stats[stat] = []
288
- stats[stat].append(result[stat])
289
- else:
290
- if stat not in stats:
291
- shape = () if np.isscalar(result[stat]) else result[stat].shape
292
- stats[stat] = np.empty((self._length,) + shape)
293
- stats[stat][i] = result[stat]
294
- return stats
295
-
296
- def reset(self):
297
- self._length = 0
298
- for metric in self._metrics_dict:
299
- metric.reset()
300
- self._metrics_dict[metric] = []
301
-
302
-
303
- class ChannelStats(BaseAggregateMetric):
304
- FLAG_METRIC_MAP = {ImageStatistics: ChannelStatisticsMetric}
305
- DEFAULT_FLAGS = [ImageStatistics.ALL]
306
- IDX_MAP = "idx_map"
307
-
308
- def __init__(self, flags: Optional[ImageStatistics] = None) -> None:
309
- super().__init__(flags)
310
-
311
- def update(self, images: Iterable[ArrayLike]) -> None:
312
- for image in to_numpy_iter(images):
313
- img = normalize_image_shape(image)
314
- for metric in self._metrics_dict:
315
- metric.update([img])
316
-
317
- for metric in self._metrics_dict:
318
- self._metrics_dict[metric] = metric.results
319
-
320
- def compute(self) -> Dict[str, Any]:
321
- # Aggregate all metrics into a single dictionary
322
- stats = {}
323
- channel_stats = set()
324
- for metric, results in self._metrics_dict.items():
325
- for i, result in enumerate(results):
326
- for stat in metric._keys():
327
- channel_stats.update(metric._keys())
328
- channels = result[stat].shape[0]
329
- stats.setdefault(self.IDX_MAP, {}).setdefault(channels, {})[i] = None
330
- stats.setdefault(stat, {}).setdefault(channels, []).append(result[stat])
331
-
332
- # Concatenate list of channel statistics numpy
333
- for stat in channel_stats:
334
- for channel in stats[stat]:
335
- stats[stat][channel] = np.array(stats[stat][channel]).T
336
-
337
- for channel in stats[self.IDX_MAP]:
338
- stats[self.IDX_MAP][channel] = list(stats[self.IDX_MAP][channel].keys())
339
-
340
- return stats
341
-
342
- def reset(self) -> None:
343
- for metric in self._metrics_dict:
344
- metric.reset()
345
- self._metrics_dict[metric] = []
205
+ stats = run_stats(images, flags, CHANNELSTATS_FN_MAP, True)
206
+
207
+ output = {}
208
+ for i, results in enumerate(stats):
209
+ for stat, result in results.items():
210
+ channels = result.shape[0]
211
+ output.setdefault(stat, {}).setdefault(channels, []).append(result)
212
+ output.setdefault(CH_IDX_MAP, {}).setdefault(channels, {})[i] = None
213
+
214
+ # Concatenate list of channel statistics numpy
215
+ for stat in output:
216
+ if stat == CH_IDX_MAP:
217
+ continue
218
+ for channel in output[stat]:
219
+ output[stat][channel] = np.array(output[stat][channel]).T
220
+
221
+ for channel in output[CH_IDX_MAP]:
222
+ output[CH_IDX_MAP][channel] = list(output[CH_IDX_MAP][channel].keys())
223
+
224
+ return StatsOutput(**populate_defaults(output, StatsOutput))
@@ -4,15 +4,17 @@ FR Test Statistic based estimate for the upperbound
4
4
  average precision using empirical mean precision
5
5
  """
6
6
 
7
- from typing import NamedTuple
7
+ from dataclasses import dataclass
8
8
 
9
9
  from numpy.typing import ArrayLike
10
10
  from sklearn.metrics import average_precision_score
11
11
 
12
12
  from dataeval._internal.interop import to_numpy
13
+ from dataeval._internal.output import OutputMetadata, set_metadata
13
14
 
14
15
 
15
- class UAPOutput(NamedTuple):
16
+ @dataclass(frozen=True)
17
+ class UAPOutput(OutputMetadata):
16
18
  """
17
19
  Attributes
18
20
  ----------
@@ -23,6 +25,7 @@ class UAPOutput(NamedTuple):
23
25
  uap: float
24
26
 
25
27
 
28
+ @set_metadata("dataeval.metrics")
26
29
  def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
27
30
  """
28
31
  FR Test Statistic based estimate of the empirical mean precision for