dataeval 0.61.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +18 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/clusterer.py +469 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/drift/base.py +265 -0
- dataeval/_internal/detectors/drift/cvm.py +97 -0
- dataeval/_internal/detectors/drift/ks.py +100 -0
- dataeval/_internal/detectors/drift/mmd.py +166 -0
- dataeval/_internal/detectors/drift/torch.py +310 -0
- dataeval/_internal/detectors/drift/uncertainty.py +149 -0
- dataeval/_internal/detectors/duplicates.py +49 -0
- dataeval/_internal/detectors/linter.py +78 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/ae.py +77 -0
- dataeval/_internal/detectors/ood/aegmm.py +69 -0
- dataeval/_internal/detectors/ood/base.py +199 -0
- dataeval/_internal/detectors/ood/llr.py +284 -0
- dataeval/_internal/detectors/ood/vae.py +86 -0
- dataeval/_internal/detectors/ood/vaegmm.py +79 -0
- dataeval/_internal/flags.py +47 -0
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/base.py +92 -0
- dataeval/_internal/metrics/ber.py +124 -0
- dataeval/_internal/metrics/coverage.py +80 -0
- dataeval/_internal/metrics/divergence.py +94 -0
- dataeval/_internal/metrics/hash.py +79 -0
- dataeval/_internal/metrics/parity.py +180 -0
- dataeval/_internal/metrics/stats.py +332 -0
- dataeval/_internal/metrics/uap.py +45 -0
- dataeval/_internal/metrics/utils.py +158 -0
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/autoencoder.py +202 -0
- dataeval/_internal/models/pytorch/blocks.py +46 -0
- dataeval/_internal/models/pytorch/utils.py +67 -0
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +317 -0
- dataeval/_internal/models/tensorflow/gmm.py +115 -0
- dataeval/_internal/models/tensorflow/losses.py +107 -0
- dataeval/_internal/models/tensorflow/pixelcnn.py +1106 -0
- dataeval/_internal/models/tensorflow/trainer.py +102 -0
- dataeval/_internal/models/tensorflow/utils.py +254 -0
- dataeval/_internal/workflows/sufficiency.py +555 -0
- dataeval/detectors/__init__.py +29 -0
- dataeval/flags/__init__.py +3 -0
- dataeval/metrics/__init__.py +7 -0
- dataeval/models/__init__.py +15 -0
- dataeval/models/tensorflow/__init__.py +6 -0
- dataeval/models/torch/__init__.py +8 -0
- dataeval/py.typed +0 -0
- dataeval/workflows/__init__.py +8 -0
- dataeval-0.61.0.dist-info/LICENSE.txt +21 -0
- dataeval-0.61.0.dist-info/METADATA +114 -0
- dataeval-0.61.0.dist-info/RECORD +55 -0
- dataeval-0.61.0.dist-info/WHEEL +4 -0
dataeval/_internal/metrics/parity.py (new file):

```python
import warnings
from typing import Optional, Tuple

import numpy as np
import scipy.stats


class Parity:
    """
    Class for evaluating statistics of observed and expected class labels, including:

    - Chi-squared test for statistical independence between expected and observed labels

    Parameters
    ----------
    expected_labels : np.ndarray
        List of class labels in the expected dataset
    observed_labels : np.ndarray
        List of class labels in the observed dataset
    num_classes : Optional[int]
        The number of unique classes in the datasets. If this is not specified, it will
        be inferred from the set of unique labels in expected_labels and observed_labels
    """

    def __init__(self, expected_labels: np.ndarray, observed_labels: np.ndarray, num_classes: Optional[int] = None):
        self.set_labels(expected_labels, observed_labels, num_classes)

    def set_labels(self, expected_labels: np.ndarray, observed_labels: np.ndarray, num_classes: Optional[int] = None):
        """
        Calculates the label distributions for expected and observed labels
        and performs validation on the results.

        Parameters
        ----------
        expected_labels : np.ndarray
            List of class labels in the expected dataset
        observed_labels : np.ndarray
            List of class labels in the observed dataset
        num_classes : Optional[int]
            The number of unique classes in the datasets. If this is not specified, it will
            be inferred from the set of unique labels in expected_labels and observed_labels

        Raises
        ------
        ValueError
            If expected_labels or observed_labels is empty
        """
        self.num_classes = num_classes

        # Calculate
        observed_dist = self._calculate_label_dist(observed_labels)
        expected_dist = self._calculate_label_dist(expected_labels)

        # Validate
        self._validate_dist(observed_dist, "observed")

        # Normalize
        expected_dist = self._normalize_expected_dist(expected_dist, observed_dist)

        # Validate normalized expected distribution
        self._validate_dist(expected_dist, f"expected for {np.sum(observed_dist)} observations")
        self._validate_class_balance(expected_dist, observed_dist)

        self._observed_dist = observed_dist
        self._expected_dist = expected_dist

    def _normalize_expected_dist(self, expected_dist: np.ndarray, observed_dist: np.ndarray) -> np.ndarray:
        exp_sum = np.sum(expected_dist)
        obs_sum = np.sum(observed_dist)

        if exp_sum == 0:
            raise ValueError(
                f"Expected label distribution {expected_dist} is all zeros. "
                "Ensure that Parity.expected_dist is set to a list "
                "with at least one nonzero element"
            )

        # Renormalize expected distribution to have the same total number of labels as the observed dataset
        if exp_sum != obs_sum:
            expected_dist = expected_dist * obs_sum / exp_sum

        return expected_dist

    def _calculate_label_dist(self, labels: np.ndarray) -> np.ndarray:
        """
        Calculate the class frequencies associated with a dataset

        Parameters
        ----------
        labels : np.ndarray
            List of class labels in a dataset

        Returns
        -------
        label_dist : np.ndarray
            Array representing label distributions
        """
        label_dist = np.bincount(labels, minlength=(self.num_classes if self.num_classes else 0))
        return label_dist

    def _validate_class_balance(self, expected_dist: np.ndarray, observed_dist: np.ndarray):
        """
        Check if the numbers of unique classes in the datasets are unequal

        Parameters
        ----------
        expected_dist : np.ndarray
            Array representing expected label distributions
        observed_dist : np.ndarray
            Array representing observed label distributions

        Raises
        ------
        ValueError
            When expected_dist and observed_dist do not have the same number of classes
        """
        exp_n_cls = len(expected_dist)
        obs_n_cls = len(observed_dist)
        if exp_n_cls != obs_n_cls:
            raise ValueError(
                f"Found {obs_n_cls} unique classes in observed label distribution, "
                f"but found {exp_n_cls} unique classes in expected label distribution. "
                "This can happen when some class ids have zero instances in one dataset but "
                "not in the other. When initializing Parity, "
                "try setting the num_classes parameter to the known number of unique class ids, "
                "so that classes with zero instances are still included in the distributions."
            )

    def _validate_dist(self, label_dist: np.ndarray, label_name: str):
        """
        Verifies that the given label distribution has labels and checks if
        any labels have frequencies less than 5.

        Parameters
        ----------
        label_dist : np.ndarray
            Array representing label distributions

        Raises
        ------
        ValueError
            If label_dist is empty
        Warning
            If any elements of label_dist are less than 5
        """
        if not len(label_dist):
            raise ValueError(f"No labels found in the {label_name} dataset")
        if np.any(label_dist < 5):
            warnings.warn(
                f"Labels {np.where(label_dist < 5)[0]} in {label_name}"
                " dataset have frequencies less than 5. This may lead"
                " to invalid chi-squared evaluation."
            )

    def evaluate(self) -> Tuple[np.float64, np.float64]:
        """
        Perform a one-way chi-squared test between observation frequencies and expected frequencies that
        tests the null hypothesis that the observed data has the expected frequencies.

        This function acts as an interface to the scipy.stats.chisquare method, which is documented at
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html

        Returns
        -------
        np.float64
            chi-squared value of the test
        np.float64
            p-value of the test
        """
        cs_result = scipy.stats.chisquare(f_obs=self._observed_dist, f_exp=self._expected_dist)

        chisquared = cs_result.statistic
        p_value = cs_result.pvalue
        return chisquared, p_value
```
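For orientation, a minimal usage sketch of `Parity`. The label arrays are illustrative, and the import uses the internal module path shown in this diff rather than a documented public alias:

```python
import numpy as np

from dataeval._internal.metrics.parity import Parity

# Illustrative integer class labels for three classes; each class count is
# kept at 5 or more to avoid the low-frequency chi-squared warning.
expected = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2] * 3)
observed = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2] * 3)

# num_classes pins the distribution length so that classes with zero
# instances in one dataset still line up between the two distributions.
parity = Parity(expected, observed, num_classes=3)

chi_squared, p_value = parity.evaluate()
print(f"chi-squared: {chi_squared:.3f}, p-value: {p_value:.3f}")
```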
dataeval/_internal/metrics/stats.py (new file):

```python
from enum import Flag
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, TypeVar, Union

import numpy as np
from scipy.stats import entropy, kurtosis, skew

from dataeval._internal.flags import ImageHash, ImageProperty, ImageStatistics, ImageStatsFlags, ImageVisuals
from dataeval._internal.metrics.base import MetricMixin
from dataeval._internal.metrics.hash import pchash, xxhash
from dataeval._internal.metrics.utils import edge_filter, get_bitdepth, normalize_image_shape, rescale

QUARTILES = (0, 25, 50, 75, 100)

TBatch = TypeVar("TBatch", bound=Sequence)
TFlag = TypeVar("TFlag", bound=Flag)


class BaseStatsMetric(MetricMixin, Generic[TBatch, TFlag]):
    def __init__(self, flags: TFlag):
        self.flags = flags
        self.results = []

    def update(self, preds: TBatch, targets=None) -> None:
        """
        Updates internal metric cache for later calculation

        Parameters
        ----------
        preds : Sequence
            Sequence of images to be processed
        """

    def compute(self) -> Dict[str, Any]:
        """
        Computes the specified measures on the cached values

        Returns
        -------
        Dict[str, Any]
            Dictionary results of the specified measures
        """
        return {stat: [result[stat] for result in self.results] for stat in self.results[0]}

    def reset(self) -> None:
        """
        Resets the internal metric cache
        """
        self.results = []

    def _map(self, func_map: Dict[Flag, Callable]) -> Dict[str, Any]:
        """Calculates the measures for each flag if it is selected."""
        results = {}
        for flag, func in func_map.items():
            if not flag.name:
                raise ValueError("Provided flag to set value does not have a name.")
            if flag & self.flags:
                results[flag.name.lower()] = func()
        return results

    def _keys(self) -> List[str]:
        """Returns the list of measures to be calculated."""
        flags = (
            self.flags
            if isinstance(self.flags, Iterable)  # py3.11
            else [flag for flag in list(self.flags.__class__) if flag & self.flags]
        )
        return [flag.name.lower() for flag in flags if flag.name is not None]


class ImageHashMetric(BaseStatsMetric):
    """
    Hashes images using the specified algorithms

    Parameters
    ----------
    flags : ImageHash
        Algorithm(s) to calculate a hash as hex digest
    """

    def __init__(self, flags: ImageHash = ImageHash.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for data in preds:
            results = self._map(
                {
                    ImageHash.XXHASH: lambda: xxhash(data),
                    ImageHash.PCHASH: lambda: pchash(data),
                }
            )
            self.results.append(results)


class ImagePropertyMetric(BaseStatsMetric):
    """
    Calculates specified image properties

    Parameters
    ----------
    flags : ImageProperty
        Property(ies) to calculate for each image
    """

    def __init__(self, flags: ImageProperty = ImageProperty.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for data in preds:
            results = self._map(
                {
                    ImageProperty.WIDTH: lambda: np.int32(data.shape[-1]),
                    ImageProperty.HEIGHT: lambda: np.int32(data.shape[-2]),
                    ImageProperty.SIZE: lambda: np.int32(data.shape[-1] * data.shape[-2]),
                    ImageProperty.ASPECT_RATIO: lambda: data.shape[-1] / np.int32(data.shape[-2]),
                    ImageProperty.CHANNELS: lambda: data.shape[-3],
                    ImageProperty.DEPTH: lambda: get_bitdepth(data).depth,
                }
            )
            self.results.append(results)


class ImageVisualsMetric(BaseStatsMetric):
    """
    Calculates specified visual image properties

    Parameters
    ----------
    flags : ImageVisuals
        Property(ies) to calculate for each image
    """

    def __init__(self, flags: ImageVisuals = ImageVisuals.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for data in preds:
            results = self._map(
                {
                    ImageVisuals.BRIGHTNESS: lambda: np.mean(rescale(data)),
                    ImageVisuals.BLURRINESS: lambda: np.std(edge_filter(np.mean(data, axis=0))),
                    ImageVisuals.MISSING: lambda: np.sum(np.isnan(data)),
                    ImageVisuals.ZERO: lambda: np.int32(np.count_nonzero(data == 0)),
                }
            )
            self.results.append(results)


class ImageStatisticsMetric(BaseStatsMetric):
    """
    Calculates descriptive statistics for each image

    Parameters
    ----------
    flags : ImageStatistics
        Statistic(s) to calculate for each image
    """

    def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for data in preds:
            scaled = rescale(data)
            if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
                hist = np.histogram(scaled, bins=256, range=(0, 1))[0]

            results = self._map(
                {
                    ImageStatistics.MEAN: lambda: np.mean(scaled),
                    ImageStatistics.STD: lambda: np.std(scaled),
                    ImageStatistics.VAR: lambda: np.var(scaled),
                    ImageStatistics.SKEW: lambda: np.float32(skew(scaled.ravel())),
                    ImageStatistics.KURTOSIS: lambda: np.float32(kurtosis(scaled.ravel())),
                    ImageStatistics.PERCENTILES: lambda: np.percentile(scaled, q=QUARTILES),
                    ImageStatistics.HISTOGRAM: lambda: hist,
                    ImageStatistics.ENTROPY: lambda: np.float32(entropy(hist)),
                }
            )
            self.results.append(results)


class ChannelStatisticsMetric(BaseStatsMetric):
    """
    Calculates descriptive statistics for each image per channel

    Parameters
    ----------
    flags : ImageStatistics
        Statistic(s) to calculate for each image per channel
    """

    def __init__(self, flags: ImageStatistics = ImageStatistics.ALL):
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for data in preds:
            scaled = rescale(data)
            flattened = scaled.reshape(data.shape[0], -1)

            if (ImageStatistics.HISTOGRAM | ImageStatistics.ENTROPY) & self.flags:
                hist = np.apply_along_axis(lambda x: np.histogram(x, bins=256, range=(0, 1))[0], 1, flattened)

            results = self._map(
                {
                    ImageStatistics.MEAN: lambda: np.mean(flattened, axis=1),
                    ImageStatistics.STD: lambda: np.std(flattened, axis=1),
                    ImageStatistics.VAR: lambda: np.var(flattened, axis=1),
                    ImageStatistics.SKEW: lambda: skew(flattened, axis=1),
                    ImageStatistics.KURTOSIS: lambda: kurtosis(flattened, axis=1),
                    ImageStatistics.PERCENTILES: lambda: np.percentile(flattened, q=QUARTILES, axis=1).T,
                    ImageStatistics.HISTOGRAM: lambda: hist,
                    ImageStatistics.ENTROPY: lambda: entropy(hist, axis=1),
                }
            )
            self.results.append(results)


class BaseAggregateMetric(BaseStatsMetric, Generic[TFlag]):
    FLAG_METRIC_MAP: Dict[type, type]
    DEFAULT_FLAGS: Sequence[TFlag]

    def __init__(self, flags: Optional[Union[TFlag, Sequence[TFlag]]] = None):
        flag_dict = {}
        for flag in flags if isinstance(flags, Sequence) else self.DEFAULT_FLAGS if not flags else [flags]:
            flag_dict[type(flag)] = flag_dict.setdefault(type(flag), type(flag)(0)) | flag
        self._metrics_dict = {
            metric: []
            for metric in (
                self.FLAG_METRIC_MAP[flag_class](flag) for flag_class, flag in flag_dict.items() if flag.value != 0
            )
        }


class ImageStats(BaseAggregateMetric):
    """
    Calculates various image property statistics

    Parameters
    ----------
    flags : [ImageHash | ImageProperty | ImageStatistics | ImageVisuals], default None
        Metric(s) to calculate for each image per channel - calculates all metrics if None
    """

    FLAG_METRIC_MAP = {
        ImageHash: ImageHashMetric,
        ImageProperty: ImagePropertyMetric,
        ImageStatistics: ImageStatisticsMetric,
        ImageVisuals: ImageVisualsMetric,
    }
    DEFAULT_FLAGS = [ImageHash.ALL, ImageProperty.ALL, ImageStatistics.ALL, ImageVisuals.ALL]

    def __init__(self, flags: Optional[Union[ImageStatsFlags, Sequence[ImageStatsFlags]]] = None):
        super().__init__(flags)
        self._length = 0

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for image in preds:
            self._length += 1
            img = normalize_image_shape(image)
            for metric in self._metrics_dict:
                metric.update([img])

    def compute(self) -> Dict[str, Any]:
        for metric in self._metrics_dict:
            self._metrics_dict[metric] = metric.results

        stats = {}
        for metric, results in self._metrics_dict.items():
            for i, result in enumerate(results):
                for stat in metric._keys():
                    value = result[stat]
                    if not isinstance(value, (np.ndarray, np.generic)):
                        if stat not in stats:
                            stats[stat] = []
                        stats[stat].append(result[stat])
                    else:
                        if stat not in stats:
                            shape = () if np.isscalar(result[stat]) else result[stat].shape
                            stats[stat] = np.empty((self._length,) + shape)
                        stats[stat][i] = result[stat]
        return stats

    def reset(self):
        self._length = 0
        for metric in self._metrics_dict:
            metric.reset()
            self._metrics_dict[metric] = []


class ChannelStats(BaseAggregateMetric):
    FLAG_METRIC_MAP = {ImageStatistics: ChannelStatisticsMetric}
    DEFAULT_FLAGS = [ImageStatistics.ALL]
    IDX_MAP = "idx_map"

    def __init__(self, flags: Optional[ImageStatistics] = None) -> None:
        super().__init__(flags)

    def update(self, preds: Iterable[np.ndarray], targets=None) -> None:
        for image in preds:
            img = normalize_image_shape(image)
            for metric in self._metrics_dict:
                metric.update([img])

        for metric in self._metrics_dict:
            self._metrics_dict[metric] = metric.results

    def compute(self) -> Dict[str, Any]:
        # Aggregate all metrics into a single dictionary
        stats = {}
        channel_stats = set()
        for metric, results in self._metrics_dict.items():
            for i, result in enumerate(results):
                for stat in metric._keys():
                    channel_stats.update(metric._keys())
                    channels = result[stat].shape[0]
                    stats.setdefault(self.IDX_MAP, {}).setdefault(channels, {})[i] = None
                    stats.setdefault(stat, {}).setdefault(channels, []).append(result[stat])

        # Convert each per-channel list of statistics into a numpy array
        for stat in channel_stats:
            for channel in stats[stat]:
                stats[stat][channel] = np.array(stats[stat][channel]).T

        for channel in stats[self.IDX_MAP]:
            stats[self.IDX_MAP][channel] = list(stats[self.IDX_MAP][channel].keys())

        return stats

    def reset(self) -> None:
        for metric in self._metrics_dict:
            metric.reset()
            self._metrics_dict[metric] = []
```
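A short sketch of driving the aggregate metrics defined above, assuming channel-first (CHW) uint8 images and the Python 3.11 flag-iteration path noted in `_keys`; the random batch and the particular flag combinations are illustrative, and the imports use the internal paths from this diff:

```python
import numpy as np

from dataeval._internal.flags import ImageProperty, ImageStatistics
from dataeval._internal.metrics.stats import ChannelStats, ImageStats

# Illustrative batch: four random 3-channel 32x32 images in CHW layout.
rng = np.random.default_rng(0)
images = [rng.integers(0, 255, size=(3, 32, 32), dtype=np.uint8) for _ in range(4)]

# Restrict ImageStats to two flag groups; passing None computes all metrics.
stats = ImageStats([ImageProperty.ALL, ImageStatistics.MEAN | ImageStatistics.STD])
stats.update(images)
results = stats.compute()
print(results["width"], results["mean"])

# ChannelStats keys its output by channel count and records which image
# indices contributed under the "idx_map" entry.
channel_stats = ChannelStats(ImageStatistics.MEAN)
channel_stats.update(images)
per_channel = channel_stats.compute()
print(per_channel["mean"][3])     # per-channel means for the 3-channel images
print(per_channel["idx_map"][3])  # indices of the images with 3 channels
```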
dataeval/_internal/metrics/uap.py (new file):

```python
"""
This module contains the implementation of the
FR Test Statistic based estimate for the upper bound
average precision using empirical mean precision
"""

from typing import Dict

import numpy as np
from sklearn.metrics import average_precision_score

from dataeval._internal.metrics.base import EvaluateMixin


class UAP(EvaluateMixin):
    """
    FR Test Statistic based estimate of the empirical mean precision

    Parameters
    ----------
    labels : np.ndarray
        A numpy array of n_samples of class labels with M unique classes.

    scores : np.ndarray
        A 2D array of class probabilities per image
    """

    def __init__(self, labels: np.ndarray, scores: np.ndarray) -> None:
        self.labels = labels
        self.scores = scores

    def evaluate(self) -> Dict[str, float]:
        """
        Returns
        -------
        Dict[str, float]
            uap : The empirical mean precision estimate

        Raises
        ------
        ValueError
            If unique classes M < 2
        """
        uap = float(average_precision_score(self.labels, self.scores, average="weighted"))
        return {"uap": uap}
```
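A minimal sketch of `UAP` on synthetic data. Note that `sklearn.metrics.average_precision_score` expects binary or multilabel-indicator targets, so 1-D multiclass labels are binarized here first; the data and the internal import path are illustrative:

```python
import numpy as np
from sklearn.preprocessing import label_binarize

from dataeval._internal.metrics.uap import UAP

# Illustrative labels and per-class probability scores: 6 samples, 3 classes.
labels = np.array([0, 1, 2, 0, 1, 2])
scores = np.array(
    [
        [0.8, 0.1, 0.1],
        [0.2, 0.7, 0.1],
        [0.1, 0.2, 0.7],
        [0.6, 0.3, 0.1],
        [0.3, 0.5, 0.2],
        [0.1, 0.1, 0.8],
    ]
)

# Binarize labels into an (n_samples, n_classes) indicator matrix.
uap = UAP(label_binarize(labels, classes=[0, 1, 2]), scores)
print(uap.evaluate())  # e.g. {'uap': 1.0} for these perfectly ranked scores
```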