dataeval 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. dataeval/__init__.py +18 -0
  2. dataeval/_internal/detectors/__init__.py +0 -0
  3. dataeval/_internal/detectors/clusterer.py +469 -0
  4. dataeval/_internal/detectors/drift/__init__.py +0 -0
  5. dataeval/_internal/detectors/drift/base.py +265 -0
  6. dataeval/_internal/detectors/drift/cvm.py +97 -0
  7. dataeval/_internal/detectors/drift/ks.py +100 -0
  8. dataeval/_internal/detectors/drift/mmd.py +166 -0
  9. dataeval/_internal/detectors/drift/torch.py +310 -0
  10. dataeval/_internal/detectors/drift/uncertainty.py +149 -0
  11. dataeval/_internal/detectors/duplicates.py +49 -0
  12. dataeval/_internal/detectors/linter.py +78 -0
  13. dataeval/_internal/detectors/ood/__init__.py +0 -0
  14. dataeval/_internal/detectors/ood/ae.py +77 -0
  15. dataeval/_internal/detectors/ood/aegmm.py +69 -0
  16. dataeval/_internal/detectors/ood/base.py +199 -0
  17. dataeval/_internal/detectors/ood/llr.py +284 -0
  18. dataeval/_internal/detectors/ood/vae.py +86 -0
  19. dataeval/_internal/detectors/ood/vaegmm.py +79 -0
  20. dataeval/_internal/flags.py +47 -0
  21. dataeval/_internal/metrics/__init__.py +0 -0
  22. dataeval/_internal/metrics/base.py +92 -0
  23. dataeval/_internal/metrics/ber.py +124 -0
  24. dataeval/_internal/metrics/coverage.py +80 -0
  25. dataeval/_internal/metrics/divergence.py +94 -0
  26. dataeval/_internal/metrics/hash.py +79 -0
  27. dataeval/_internal/metrics/parity.py +180 -0
  28. dataeval/_internal/metrics/stats.py +332 -0
  29. dataeval/_internal/metrics/uap.py +45 -0
  30. dataeval/_internal/metrics/utils.py +158 -0
  31. dataeval/_internal/models/__init__.py +0 -0
  32. dataeval/_internal/models/pytorch/__init__.py +0 -0
  33. dataeval/_internal/models/pytorch/autoencoder.py +202 -0
  34. dataeval/_internal/models/pytorch/blocks.py +46 -0
  35. dataeval/_internal/models/pytorch/utils.py +67 -0
  36. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  37. dataeval/_internal/models/tensorflow/autoencoder.py +317 -0
  38. dataeval/_internal/models/tensorflow/gmm.py +115 -0
  39. dataeval/_internal/models/tensorflow/losses.py +107 -0
  40. dataeval/_internal/models/tensorflow/pixelcnn.py +1106 -0
  41. dataeval/_internal/models/tensorflow/trainer.py +102 -0
  42. dataeval/_internal/models/tensorflow/utils.py +254 -0
  43. dataeval/_internal/workflows/sufficiency.py +555 -0
  44. dataeval/detectors/__init__.py +29 -0
  45. dataeval/flags/__init__.py +3 -0
  46. dataeval/metrics/__init__.py +7 -0
  47. dataeval/models/__init__.py +15 -0
  48. dataeval/models/tensorflow/__init__.py +6 -0
  49. dataeval/models/torch/__init__.py +8 -0
  50. dataeval/py.typed +0 -0
  51. dataeval/workflows/__init__.py +8 -0
  52. dataeval-0.61.0.dist-info/LICENSE.txt +21 -0
  53. dataeval-0.61.0.dist-info/METADATA +114 -0
  54. dataeval-0.61.0.dist-info/RECORD +55 -0
  55. dataeval-0.61.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,180 @@
1
+ import warnings
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import scipy
6
+
7
+
8
class Parity:
    """
    Class for evaluating statistics of observed and expected class labels, including:

    - One-way chi-squared (goodness-of-fit) test of the observed label frequencies
      against the expected label frequencies

    Parameters
    ----------
    expected_labels : np.ndarray
        List of class labels in the expected dataset
    observed_labels : np.ndarray
        List of class labels in the observed dataset
    num_classes : Optional[int]
        The number of unique classes in the datasets, used as the minimum length of
        each label distribution. If this is not specified, each distribution's length
        is one more than the largest label it contains (``np.bincount`` behavior), so
        set it when trailing classes may have zero instances.

    Notes
    -----
    Labels are counted with ``np.bincount``, so they are assumed to be non-negative
    integers.
    """

    def __init__(self, expected_labels: np.ndarray, observed_labels: np.ndarray, num_classes: Optional[int] = None):
        self.set_labels(expected_labels, observed_labels, num_classes)

    def set_labels(self, expected_labels: np.ndarray, observed_labels: np.ndarray, num_classes: Optional[int] = None):
        """
        Calculates the label distributions for expected and observed labels
        and performs validation on the results.

        The expected distribution is rescaled so that its total matches the number
        of observed labels before it is validated.

        Parameters
        ----------
        expected_labels : np.ndarray
            List of class labels in the expected dataset
        observed_labels : np.ndarray
            List of class labels in the observed dataset
        num_classes : Optional[int]
            The number of unique classes in the datasets, used as the minimum length
            of each label distribution. If this is not specified, each distribution's
            length is one more than the largest label it contains.

        Raises
        ------
        ValueError
            If the observed label distribution is empty, if the expected label
            distribution is all zeros, or if the two distributions have a different
            number of classes

        Warns
        -----
        UserWarning
            If any class has a frequency below 5 in either distribution
        """
        self.num_classes = num_classes

        # Calculate
        observed_dist = self._calculate_label_dist(observed_labels)
        expected_dist = self._calculate_label_dist(expected_labels)

        # Validate
        self._validate_dist(observed_dist, "observed")

        # Normalize
        expected_dist = self._normalize_expected_dist(expected_dist, observed_dist)

        # Validate normalized expected distribution
        self._validate_dist(expected_dist, f"expected for {np.sum(observed_dist)} observations")
        self._validate_class_balance(expected_dist, observed_dist)

        self._observed_dist = observed_dist
        self._expected_dist = expected_dist

    def _normalize_expected_dist(self, expected_dist: np.ndarray, observed_dist: np.ndarray) -> np.ndarray:
        """
        Rescales the expected distribution so that it sums to the same total as the
        observed distribution (scipy.stats.chisquare expects matching totals).

        Parameters
        ----------
        expected_dist : np.ndarray
            Array representing expected label distributions
        observed_dist : np.ndarray
            Array representing observed label distributions

        Returns
        -------
        np.ndarray
            The rescaled expected distribution

        Raises
        ------
        ValueError
            If the expected distribution sums to zero
        """
        exp_sum = np.sum(expected_dist)
        obs_sum = np.sum(observed_dist)

        if exp_sum == 0:
            raise ValueError(
                f"Expected label distribution {expected_dist} is all zeros. "
                "Ensure that Parity.expected_dist is set to a list "
                "with at least one nonzero element"
            )

        # Renormalize expected distribution to have the same total number of labels as the observed dataset
        if exp_sum != obs_sum:
            expected_dist = expected_dist * obs_sum / exp_sum

        return expected_dist

    def _calculate_label_dist(self, labels: np.ndarray) -> np.ndarray:
        """
        Calculate the class frequencies associated with a dataset

        Parameters
        ----------
        labels : np.ndarray
            List of class labels in a dataset

        Returns
        -------
        label_dist : np.ndarray
            Array representing label distributions; its length is at least
            ``self.num_classes`` when that is set
        """
        label_dist = np.bincount(labels, minlength=(self.num_classes if self.num_classes else 0))
        return label_dist

    def _validate_class_balance(self, expected_dist: np.ndarray, observed_dist: np.ndarray):
        """
        Check if the numbers of unique classes in the datasets are unequal

        Parameters
        ----------
        expected_dist : np.ndarray
            Array representing expected label distributions
        observed_dist : np.ndarray
            Array representing observed label distributions

        Raises
        ------
        ValueError
            When expected_dist and observed_dist do not have the same number of classes
        """
        exp_n_cls = len(expected_dist)
        obs_n_cls = len(observed_dist)
        if exp_n_cls != obs_n_cls:
            raise ValueError(
                f"Found {obs_n_cls} unique classes in observed label distribution, "
                f"but found {exp_n_cls} unique classes in expected label distribution. "
                "This can happen when some class ids have zero instances in one dataset but "
                "not in the other. When initializing Parity, "
                "try setting the num_classes parameter to the known number of unique class ids, "
                "so that classes with zero instances are still included in the distributions."
            )

    def _validate_dist(self, label_dist: np.ndarray, label_name: str):
        """
        Verifies that the given label distribution has labels and checks if
        any labels have frequencies less than 5.

        Parameters
        ----------
        label_dist : np.ndarray
            Array representing label distributions
        label_name : str
            Description of the distribution, used in error and warning messages

        Raises
        ------
        ValueError
            If label_dist is empty

        Warns
        -----
        UserWarning
            Once, if any elements of label_dist are less than 5
        """
        if not len(label_dist):
            raise ValueError(f"No labels found in the {label_name} dataset")
        if np.any(label_dist < 5):
            # Chi-squared approximations become unreliable for small expected counts
            warnings.warn(
                f"Labels {np.where(label_dist<5)[0]} in {label_name}"
                " dataset have frequencies less than 5. This may lead"
                " to invalid chi-squared evaluation."
            )

    def evaluate(self) -> Tuple[np.float64, np.float64]:
        """
        Perform a one-way chi-squared test between observation frequencies and expected frequencies that
        tests the null hypothesis that the observed data has the expected frequencies.

        This function acts as an interface to the scipy.stats.chisquare method, which is documented at
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html

        Returns
        -------
        np.float64
            chi-squared value of the test
        np.float64
            p-value of the test
        """
        cs_result = scipy.stats.chisquare(f_obs=self._observed_dist, f_exp=self._expected_dist)

        chisquared = cs_result.statistic
        p_value = cs_result.pvalue
        return chisquared, p_value
@@ -0,0 +1,332 @@
1
+ from enum import Flag
2
+ from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
3
+
4
+ import numpy as np
5
+ from scipy.stats import entropy, kurtosis, skew
6
+
7
+ from dataeval._internal.flags import ImageHash, ImageProperty, ImageStatistics, ImageStatsFlags, ImageVisuals
8
+ from dataeval._internal.metrics.base import MetricMixin
9
+ from dataeval._internal.metrics.hash import pchash, xxhash
10
+ from dataeval._internal.metrics.utils import edge_filter, get_bitdepth, normalize_image_shape, rescale
11
+
12
# Percentiles reported by the PERCENTILES statistic: minimum, quartiles and maximum.
QUARTILES = (0, 25, 50, 75, 100)

# Type variables parameterizing BaseStatsMetric: the batch container type and the flag type.
TBatch = TypeVar("TBatch", bound=Sequence)
TFlag = TypeVar("TFlag", bound=Flag)
16
+
17
+
18
class BaseStatsMetric(MetricMixin, Generic[TBatch, TFlag]):
    """
    Base class for per-image statistics metrics.

    Subclasses implement ``update`` to append one result dictionary per image to
    ``self.results``. ``compute`` then regroups those dictionaries into one list of
    values per statistic.

    Parameters
    ----------
    flags : TFlag
        Flag (or combination of flags) selecting which measures to calculate
    """

    def __init__(self, flags: TFlag):
        self.flags = flags
        self.results: List[Dict[str, Any]] = []

    def update(self, preds: TBatch, targets=None) -> None:
        """
        Updates internal metric cache for later calculation

        The base implementation does nothing; subclasses override it.

        Parameters
        ----------
        preds : TBatch
            Sequence of images to be processed
        targets : optional
            Unused by the statistics metrics
        """

    def compute(self) -> Dict[str, Any]:
        """
        Computes the specified measures on the cached values

        Returns
        -------
        Dict[str, Any]
            Dictionary mapping each measure name to the list of its values, one per
            cached image. Empty if nothing has been cached yet.
        """
        # Guard: without this, an empty cache raised IndexError on self.results[0]
        if not self.results:
            return {}
        return {stat: [result[stat] for result in self.results] for stat in self.results[0]}

    def reset(self) -> None:
        """
        Resets the internal metric cache
        """
        self.results = []

    def _map(self, func_map: Dict[Flag, Callable]) -> Dict[str, Any]:
        """
        Calculates the measures for each flag if it is selected.

        Each callable is invoked only when its flag intersects ``self.flags``; its
        result is stored under the lower-cased flag name.

        Raises
        ------
        ValueError
            If a flag in ``func_map`` has no name
        """
        results = {}
        for flag, func in func_map.items():
            if not flag.name:
                raise ValueError("Provided flag to set value does not have a name.")
            if flag & self.flags:
                results[flag.name.lower()] = func()
        return results

    def _keys(self) -> List[str]:
        """Returns the list of measures to be calculated (lower-cased flag names)."""
        # Flag values are directly iterable from Python 3.11; on older versions,
        # enumerate the flag class members and keep those set in self.flags.
        flags = (
            self.flags
            if isinstance(self.flags, Iterable)  # py3.11
            else [flag for flag in list(self.flags.__class__) if flag & self.flags]
        )
        return [flag.name.lower() for flag in flags if flag.name is not None]
68
+
69
+
70
class ImageHashMetric(BaseStatsMetric):
    """
    Hashes images using the specified algorithms

    Parameters
    ----------
    flags : ImageHash
        Algorithm(s) used to calculate a hash as a hex digest (all by default)
    """

    def __init__(self, flags: ImageHash = ImageHash.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        # One result dictionary is appended per image; a hasher only runs when
        # its flag is selected in self.flags (see BaseStatsMetric._map).
        for image in preds:
            hashers = {
                ImageHash.XXHASH: lambda: xxhash(image),
                ImageHash.PCHASH: lambda: pchash(image),
            }
            self.results.append(self._map(hashers))
92
+
93
+
94
class ImagePropertyMetric(BaseStatsMetric):
    """
    Calculates specified image properties

    Parameters
    ----------
    flags: ImageProperty
        Property(ies) to calculate for each image (all by default)
    """

    def __init__(self, flags: ImageProperty = ImageProperty.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        # Width/height are read from the last two axes and channels from the third
        # to last; each property is computed only if its flag is selected.
        for image in preds:
            properties = {
                ImageProperty.WIDTH: lambda: np.int32(image.shape[-1]),
                ImageProperty.HEIGHT: lambda: np.int32(image.shape[-2]),
                ImageProperty.SIZE: lambda: np.int32(image.shape[-1] * image.shape[-2]),
                ImageProperty.ASPECT_RATIO: lambda: image.shape[-1] / np.int32(image.shape[-2]),
                ImageProperty.CHANNELS: lambda: image.shape[-3],
                ImageProperty.DEPTH: lambda: get_bitdepth(image).depth,
            }
            self.results.append(self._map(properties))
120
+
121
+
122
class ImageVisualsMetric(BaseStatsMetric):
    """
    Calculates specified visual image properties

    Parameters
    ----------
    flags: ImageVisuals
        Property(ies) to calculate for each image (all by default)
    """

    def __init__(self, flags: ImageVisuals = ImageVisuals.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for image in preds:
            # Each entry is evaluated only if its flag is selected in self.flags.
            visuals = {
                # mean intensity of the rescaled image
                ImageVisuals.BRIGHTNESS: lambda: np.mean(rescale(image)),
                # std of the edge-filtered image averaged over axis 0 (presumably channels)
                ImageVisuals.BLURRINESS: lambda: np.std(edge_filter(np.mean(image, axis=0))),
                # number of NaN values
                ImageVisuals.MISSING: lambda: np.sum(np.isnan(image)),
                # number of values equal to zero
                ImageVisuals.ZERO: lambda: np.int32(np.count_nonzero(image == 0)),
            }
            self.results.append(self._map(visuals))
146
+
147
+
148
class ImageStatisticsMetric(BaseStatsMetric):
    """
    Calculates descriptive statistics for each image

    Statistics are computed on the image after ``rescale``; the histogram uses 256
    bins over the range [0, 1].

    Parameters
    ----------
    flags: ImageStatistics
        Statistic(s) to calculate for each image (all by default)
    """

    def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for image in preds:
            scaled = rescale(image)

            # The histogram is only needed by the HISTOGRAM and ENTROPY statistics.
            if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
                hist = np.histogram(scaled, bins=256, range=(0, 1))[0]

            statistics = {
                ImageStatistics.MEAN: lambda: np.mean(scaled),
                ImageStatistics.STD: lambda: np.std(scaled),
                ImageStatistics.VAR: lambda: np.var(scaled),
                ImageStatistics.SKEW: lambda: np.float32(skew(scaled.ravel())),
                ImageStatistics.KURTOSIS: lambda: np.float32(kurtosis(scaled.ravel())),
                ImageStatistics.PERCENTILES: lambda: np.percentile(scaled, q=QUARTILES),
                ImageStatistics.HISTOGRAM: lambda: hist,
                ImageStatistics.ENTROPY: lambda: np.float32(entropy(hist)),
            }
            self.results.append(self._map(statistics))
180
+
181
+
182
class ChannelStatisticsMetric(BaseStatsMetric):
    """
    Calculates descriptive statistics for each image per channel

    The rescaled image is reshaped to ``(image.shape[0], -1)`` so that each row
    holds the values of one slice along axis 0 (presumably one channel), and every
    statistic is computed row-wise.

    Parameters
    ----------
    flags: ImageStatistics
        Statistic(s) to calculate for each image per channel (all by default)
    """

    def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for image in preds:
            scaled = rescale(image)
            rows = scaled.reshape(image.shape[0], -1)

            # Per-row 256-bin histograms over [0, 1], needed only by HISTOGRAM/ENTROPY.
            if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
                hist = np.apply_along_axis(lambda row: np.histogram(row, bins=256, range=(0, 1))[0], 1, rows)

            statistics = {
                ImageStatistics.MEAN: lambda: np.mean(rows, axis=1),
                ImageStatistics.STD: lambda: np.std(rows, axis=1),
                ImageStatistics.VAR: lambda: np.var(rows, axis=1),
                ImageStatistics.SKEW: lambda: skew(rows, axis=1),
                ImageStatistics.KURTOSIS: lambda: kurtosis(rows, axis=1),
                ImageStatistics.PERCENTILES: lambda: np.percentile(rows, q=QUARTILES, axis=1).T,
                ImageStatistics.HISTOGRAM: lambda: hist,
                ImageStatistics.ENTROPY: lambda: entropy(hist, axis=1),
            }
            self.results.append(self._map(statistics))
216
+
217
+
218
class BaseAggregateMetric(BaseStatsMetric, Generic[TFlag]):
    """
    Base class for metrics that dispatch to one underlying metric per flag class.

    Subclasses define ``FLAG_METRIC_MAP`` (flag class -> metric class) and
    ``DEFAULT_FLAGS``. All requested flags of the same class are OR-ed together, and
    one metric is instantiated for each flag class whose combined value is non-zero.
    The instances are stored as keys of ``self._metrics_dict``, each mapped to an
    initially empty list.

    Parameters
    ----------
    flags : TFlag or Sequence[TFlag], optional
        Flag(s) selecting what to compute; ``DEFAULT_FLAGS`` is used when this is
        None or otherwise falsy
    """

    FLAG_METRIC_MAP: Dict[type, type]
    DEFAULT_FLAGS: Sequence[TFlag]

    def __init__(self, flags: Optional[Union[TFlag, Sequence[TFlag]]] = None):
        if isinstance(flags, Sequence):
            requested = flags
        elif flags:
            requested = [flags]
        else:
            requested = self.DEFAULT_FLAGS

        # OR together every flag that shares a flag class
        combined: Dict[type, Any] = {}
        for flag in requested:
            flag_cls = type(flag)
            combined[flag_cls] = combined.get(flag_cls, flag_cls(0)) | flag

        # One metric per flag class with at least one flag set
        self._metrics_dict = {}
        for flag_cls, merged in combined.items():
            if merged.value != 0:
                self._metrics_dict[self.FLAG_METRIC_MAP[flag_cls](merged)] = []
232
+
233
+
234
class ImageStats(BaseAggregateMetric):
    """
    Calculates various image property statistics

    Each image passed to ``update`` is normalized with ``normalize_image_shape`` and
    forwarded to one underlying metric per selected flag class.

    Parameters
    ----------
    flags: [ImageHash | ImageProperty | ImageStatistics | ImageVisuals], default None
        Metric(s) to calculate for each image - calculates all metrics if None
    """

    FLAG_METRIC_MAP = {
        ImageHash: ImageHashMetric,
        ImageProperty: ImagePropertyMetric,
        ImageStatistics: ImageStatisticsMetric,
        ImageVisuals: ImageVisualsMetric,
    }
    DEFAULT_FLAGS = [ImageHash.ALL, ImageProperty.ALL, ImageStatistics.ALL, ImageVisuals.ALL]

    def __init__(self, flags: Optional[Union[ImageStatsFlags, Sequence[ImageStatsFlags]]] = None):
        super().__init__(flags)
        # Number of images seen; sizes the preallocated result arrays in compute()
        self._length = 0

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        """
        Normalizes the shape of each image and forwards it to every underlying metric.

        Parameters
        ----------
        preds : Iterable[np.ndarray]
            Images to process
        targets : optional
            Unused
        """
        for image in preds:
            self._length += 1
            img = normalize_image_shape(image)
            for metric in self._metrics_dict:
                metric.update([img])

    def compute(self) -> Dict[str, Any]:
        """
        Aggregates the per-image results of all underlying metrics.

        Returns
        -------
        Dict[str, Any]
            Mapping of statistic name to its values for every image. NumPy results are
            stored in a float array whose first axis indexes the images; any other
            results (e.g. the Python int returned for the channel count) are collected
            into a list.
        """
        for metric in self._metrics_dict:
            self._metrics_dict[metric] = metric.results

        stats: Dict[str, Any] = {}
        for metric, results in self._metrics_dict.items():
            # The selected keys depend only on the metric's flags, so compute them once
            keys = metric._keys()
            for i, result in enumerate(results):
                for stat in keys:
                    value = result[stat]
                    if isinstance(value, (np.ndarray, np.generic)):
                        if stat not in stats:
                            shape = () if np.isscalar(value) else value.shape
                            stats[stat] = np.empty((self._length,) + shape)
                        stats[stat][i] = value
                    else:
                        stats.setdefault(stat, []).append(value)
        return stats

    def reset(self):
        """Clears the image count and the cached results of every underlying metric."""
        self._length = 0
        for metric in self._metrics_dict:
            metric.reset()
            self._metrics_dict[metric] = []
288
+
289
+
290
class ChannelStats(BaseAggregateMetric):
    """
    Calculates descriptive statistics for each image per channel

    Results are grouped by the number of channels (the length of each per-image
    result along its first axis), so images with different channel counts can be
    aggregated together.

    Parameters
    ----------
    flags : ImageStatistics, default None
        Statistic(s) to calculate for each image per channel - calculates all
        statistics if None
    """

    FLAG_METRIC_MAP = {ImageStatistics: ChannelStatisticsMetric}
    DEFAULT_FLAGS = [ImageStatistics.ALL]
    # Key of the computed entry that maps each channel count to the image indices having it
    IDX_MAP = "idx_map"

    def __init__(self, flags: Optional[ImageStatistics] = None) -> None:
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        """
        Normalizes the shape of each image and forwards it to the underlying metric(s).

        Parameters
        ----------
        preds : Iterable[np.ndarray]
            Images to process
        targets : optional
            Unused
        """
        for image in preds:
            img = normalize_image_shape(image)
            for metric in self._metrics_dict:
                metric.update([img])

        for metric in self._metrics_dict:
            self._metrics_dict[metric] = metric.results

    def compute(self) -> Dict[str, Any]:
        """
        Aggregates the per-channel results of all processed images.

        Returns
        -------
        Dict[str, Any]
            Mapping of statistic name to ``{channel count: array}``, where each array
            is the per-image values for that channel count stacked and transposed.
            The ``idx_map`` entry maps each channel count to the list of image
            indices with that many channels. Empty if no images were processed.
        """
        # Aggregate all metrics into a single dictionary
        stats: Dict[str, Any] = {}
        channel_stats = set()
        for metric, results in self._metrics_dict.items():
            keys = metric._keys()
            if results:
                channel_stats.update(keys)
            for i, result in enumerate(results):
                for stat in keys:
                    value = result[stat]
                    channels = value.shape[0]
                    # Dict used as an insertion-ordered set of image indices
                    stats.setdefault(self.IDX_MAP, {}).setdefault(channels, {})[i] = None
                    stats.setdefault(stat, {}).setdefault(channels, []).append(value)

        # Concatenate list of channel statistics numpy
        for stat in channel_stats:
            for channel, values in stats[stat].items():
                stats[stat][channel] = np.array(values).T

        # Guard: with no processed images there is no idx_map entry (previously a KeyError)
        for channel, indices in stats.get(self.IDX_MAP, {}).items():
            stats[self.IDX_MAP][channel] = list(indices)

        return stats

    def reset(self) -> None:
        """Clears the cached results of every underlying metric."""
        for metric in self._metrics_dict:
            metric.reset()
            self._metrics_dict[metric] = []
@@ -0,0 +1,45 @@
1
+ """
2
+ This module contains the implementation of the
3
+ FR Test Statistic based estimate for the upper bound
4
+ average precision using empirical mean precision
5
+ """
6
+
7
+ from typing import Dict
8
+
9
+ import numpy as np
10
+ from sklearn.metrics import average_precision_score
11
+
12
+ from dataeval._internal.metrics.base import EvaluateMixin
13
+
14
+
15
class UAP(EvaluateMixin):
    """
    Estimate of the empirical mean precision (upper-bound average precision)

    Parameters
    ----------
    labels : np.ndarray
        A numpy array of n_samples of class labels with M unique classes.
    scores : np.ndarray
        A 2D array of class probabilities per image
    """

    def __init__(self, labels: np.ndarray, scores: np.ndarray) -> None:
        self.labels = labels
        self.scores = scores

    def evaluate(self) -> Dict[str, float]:
        """
        Computes the class-weighted average precision of ``scores`` against ``labels``
        with ``sklearn.metrics.average_precision_score(average="weighted")``.

        Returns
        -------
        Dict[str, float]
            uap : The empirical mean precision estimate

        Raises
        ------
        ValueError
            Propagated from average_precision_score for invalid inputs, e.g. when
            there are fewer than 2 unique classes (M < 2)
        """
        weighted_ap = average_precision_score(self.labels, self.scores, average="weighted")
        return {"uap": float(weighted_ap)}