dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +13 -3
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_ood.py +144 -27
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_balance.py +3 -3
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +31 -36
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +4 -2
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_metadata.py +7 -0
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +43 -12
- dataeval/utils/_array.py +26 -1
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_dataset.py +2 -0
- dataeval/utils/data/_embeddings.py +115 -32
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_selection.py +7 -8
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +1 -1
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/metadata.py +1 -1
- dataeval/utils/torch/_gmm.py +3 -2
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
- dataeval/detectors/ood/metadata_ood_mi.py +0 -91
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations

 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.82.1"
+__version__ = "0.84.0"

 import logging

@@ -34,7 +34,12 @@ def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> N
     logger = logging.getLogger(__name__)
     if handler is None:
         handler = logging.StreamHandler() if handler is None else handler
-    handler.setFormatter(…
+    handler.setFormatter(
+        logging.Formatter(
+            "%(asctime)s %(levelname)-8s %(name)s.%(filename)s:%(lineno)s - %(funcName)10s() | %(message)s"
+        )
+    )
     logger.addHandler(handler)
     logger.setLevel(level)
+    logging.DEBUG
     logger.debug(f"Added logging handler {handler} to logger: {__name__}")
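
For context, log is DataEval's opt-in debug-logging helper; a minimal usage sketch (the formatter string is the one added above, which 0.84.0 attaches automatically when no handler is supplied):

    import logging
    import dataeval

    # Route DataEval's internal logs to a default StreamHandler at INFO level;
    # 0.84.0 now attaches the multi-field Formatter shown in the diff above.
    dataeval.log(level=logging.INFO)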
dataeval/config.py
CHANGED
@@ -17,10 +17,18 @@ else:
 import numpy as np
 import torch

+### GLOBALS ###
+
 _device: torch.device | None = None
 _processes: int | None = None
 _seed: int | None = None

+### CONSTS ###
+
+EPSILON = 1e-10
+
+### TYPES ###
+
 DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
 """
 Type alias for types that are acceptable for specifying a torch.device.
@@ -30,18 +38,20 @@ See Also
 `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
 """

+### FUNCS ###
+

 def _todevice(device: DeviceLike) -> torch.device:
     return torch.device(*device) if isinstance(device, tuple) else torch.device(device)


-def set_device(device: DeviceLike) -> None:
+def set_device(device: DeviceLike | None) -> None:
     """
     Sets the default device to use when executing against a PyTorch backend.

     Parameters
     ----------
-    device : DeviceLike
+    device : DeviceLike or None
         The default device to use. See documentation for more information.

     See Also
@@ -49,7 +59,7 @@ def set_device(device: DeviceLike) -> None:
     `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
     """
     global _device
-    _device = _todevice(device)
+    _device = None if device is None else _todevice(device)


 def get_device(override: DeviceLike | None = None) -> torch.device:
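
A minimal sketch of the widened set_device signature; the None-clearing path is new in 0.84.0, and the exact fallback get_device uses once the default is cleared is not shown in this diff:

    import torch
    from dataeval.config import get_device, set_device

    set_device("cuda:0")   # pin DataEval's torch execution to GPU 0
    print(get_device())    # expected: device(type='cuda', index=0)

    set_device(None)       # clears the stored default; get_device then falls
                           # back to its own default resolution logic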
dataeval/metadata/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""

-__all__ = ["…
+__all__ = ["find_ood_predictors", "metadata_distance", "find_most_deviated_factors"]

 from dataeval.metadata._distance import metadata_distance
-from dataeval.metadata._ood import …
+from dataeval.metadata._ood import find_most_deviated_factors, find_ood_predictors
dataeval/metadata/_ood.py
CHANGED
@@ -6,14 +6,44 @@ import warnings

 import numpy as np
 from numpy.typing import NDArray
+from sklearn.feature_selection import mutual_info_classif

+from dataeval.config import get_seed
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
-from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
+from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils.data import Metadata


-def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[…
+def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
+    """Combines the discrete and continuous data of a :class:`Metadata` object
+
+    Returns
+    -------
+    Tuple[list[str], NDArray]
+        The combined list of factor names and the combined discrete and continuous data
+
+    Note
+    ----
+    Discrete and continuous data must have the same number of samples
+    """
+    names = []
+    data = []
+
+    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
+        names.extend(metadata.discrete_factor_names)
+        data.append(metadata.discrete_data)
+
+    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
+        names.extend(metadata.continuous_factor_names)
+        data.append(metadata.continuous_data)
+
+    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
+
+
+def _combine_metadata(
+    metadata_1: Metadata, metadata_2: Metadata
+) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
     """
     Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
     match exactly and data has the same number of columns (factors).
@@ -42,8 +72,8 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[
         If the length of keys do not match the length of the data
     """
     factor_names: list[str] = []
-    m1_data: list[NDArray] = []
-    m2_data: list[NDArray] = []
+    m1_data: list[NDArray[np.int64 | np.float64]] = []
+    m2_data: list[NDArray[np.int64 | np.float64]] = []

     # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
     if metadata_1.total_num_factors != metadata_2.total_num_factors:
@@ -121,36 +151,37 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:


 @set_metadata
-def most_deviated_factors(
-    …
-    …
+def find_most_deviated_factors(
+    metadata_ref: Metadata,
+    metadata_tst: Metadata,
     ood: OODOutput,
 ) -> MostDeviatedFactorsOutput:
     """
-    …
+    Determine greatest deviation in metadata features per out of distribution sample in test metadata.

     Parameters
     ----------
-    …
+    metadata_ref : Metadata
         A reference set of Metadata containing factor names and samples
         with discrete and/or continuous values per factor
-    …
+    metadata_tst : Metadata
         The set of Metadata that is tested against the reference metadata.
         This set must have the same number of features but does not require the same number of samples.
     ood : OODOutput
-        A class output by …
+        A class output by DataEval's OOD functions that contains which examples are OOD.

     Returns
     -------
-    …
-    An …
+    MostDeviatedFactorsOutput
+        An output class containing the factor name and deviation of the highest metadata deviations for each
+        OOD example in the test metadata.

     Notes
     -----
     1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
        and have equivalent factor names and lengths
     2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
-       directly to sample `i` of `…
+       directly to sample `i` of `metadata_tst` being out-of-distribution from `metadata_ref`

     Examples
     --------
@@ -160,13 +191,13 @@ def most_deviated_factors(
     All samples are out-of-distribution

     >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
-    >>> …
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])

-    …
+    No samples are out-of-distribution

     >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
-    >>> …
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([])
     """

@@ -179,31 +210,30 @@ def most_deviated_factors(
     # Combines reference and test factor names and data if exists and match exactly
     # shape -> (samples, factors)
     factor_names, md_1, md_2 = _combine_metadata(
-        metadata_1=…
-        metadata_2=…
+        metadata_1=metadata_ref,
+        metadata_2=metadata_tst,
     )

     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-    …
-    …
+    ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
+    tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)

-    if len(…
+    if len(ref_data) < 3:
         warnings.warn(
-            f"At least 3 reference metadata samples are needed, got {len(…
+            f"At least 3 reference metadata samples are needed, got {len(ref_data)}",
             UserWarning,
         )
         return MostDeviatedFactorsOutput([])

-    if len(…
+    if len(tst_data) != len(ood_mask):
         raise ValueError(
-            f"ood and test metadata must have the same length, "
-            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+            f"ood and test metadata must have the same length, got {len(ood_mask)} and {len(tst_data)} respectively."
         )

     # Calculates deviations of all samples in m2_data
     # from the median values of the corresponding index in m1_data
     # Guaranteed for inputs to not be empty
-    deviations = _calc_median_deviations(…
+    deviations = _calc_median_deviations(ref_data, tst_data)

     # Get most impactful factor deviation of each sample for ood samples only
     deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
@@ -217,3 +247,90 @@ def most_deviated_factors(
     # List of tuples matching the factor name with its deviation

     return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
+
+
+_NATS2BITS = 1.442695
+"""
+_NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
+which is what many library functions return, multiply it by _NATS2BITS to get it in bits.
+"""
+
+
+def find_ood_predictors(
+    metadata: Metadata,
+    ood: OODOutput,
+) -> OODPredictorOutput:
+    """Computes mutual information between a set of metadata features and per sample out-of-distribution flags.
+
+    Given a set of metadata features per sample and a corresponding OODOutput that indicates whether a sample was
+    determined to be out of distribution, this function calculates the mutual information between each factor and
+    being out of distribution. In other words, it finds which metadata factors most likely correlate to an
+    out of distribution sample.
+
+    Note
+    ----
+    A high mutual information between a factor and ood samples is an indication of correlation, but not causation.
+    Additional analysis should be done to determine how to handle factors with a high mutual information.
+
+    Parameters
+    ----------
+    metadata : Metadata
+        A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
+    ood : OODOutput
+        A class output by DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    OODPredictorOutput
+        A dictionary with keys corresponding to metadata feature names, and values indicating the strength of
+        association between each named feature and the OOD flag, as mutual information measured in bits.
+
+    Examples
+    --------
+    >>> from dataeval.outputs import OODOutput
+
+    All samples are out-of-distribution
+
+    >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+    >>> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({'time': 8.008566032557951e-17, 'altitude': 8.008566032557951e-17})
+
+    No out-of-distribution samples
+
+    >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+    >>> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({})
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    discrete_features_count = len(metadata.discrete_factor_names)
+    factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+
+    # No metadata correlated with out of distribution data, return 0.0 for all factors
+    if not any(ood_mask):
+        return OODPredictorOutput(dict.fromkeys(factors, 0.0))
+
+    if len(data) != len(ood_mask):
+        raise ValueError(
+            f"ood and metadata must have the same length, got {len(ood_mask)} and {len(data)} respectively."
+        )
+
+    # Calculate mean, std of each factor over all samples
+    scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
+
+    discrete_features = np.zeros_like(factors, dtype=np.bool)
+    discrete_features[:discrete_features_count] = True
+
+    mutual_info_values = (
+        mutual_info_classif(
+            X=scaled_data,
+            y=ood_mask,
+            discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+            random_state=get_seed(),
+        )
+        * _NATS2BITS
+    )
+
+    return OODPredictorOutput({k: mutual_info_values[i] for i, k in enumerate(factors)})
dataeval/metrics/bias/__init__.py
CHANGED
@@ -6,10 +6,12 @@ representation which may impact model performance.
 __all__ = [
     "BalanceOutput",
     "CoverageOutput",
+    "CompletenessOutput",
     "DiversityOutput",
     "LabelParityOutput",
     "ParityOutput",
     "balance",
+    "completeness",
     "coverage",
     "diversity",
     "label_parity",
@@ -17,7 +19,15 @@ __all__ = [
 ]

 from dataeval.metrics.bias._balance import balance
+from dataeval.metrics.bias._completeness import completeness
 from dataeval.metrics.bias._coverage import coverage
 from dataeval.metrics.bias._diversity import diversity
 from dataeval.metrics.bias._parity import label_parity, parity
-from dataeval.outputs._bias import …
+from dataeval.outputs._bias import (
+    BalanceOutput,
+    CompletenessOutput,
+    CoverageOutput,
+    DiversityOutput,
+    LabelParityOutput,
+    ParityOutput,
+)
dataeval/metrics/bias/_balance.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
 import scipy as sp
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

-from dataeval.config import get_seed
+from dataeval.config import EPSILON, get_seed
 from dataeval.outputs import BalanceOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils._bin import get_counts
@@ -128,7 +128,7 @@ def balance(
     # Normalization via entropy
     bin_cnts = get_counts(discretized_data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + …
+    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON

     # in principle MI should be symmetric, but it is not in practice.
     nmi = 0.5 * (mi + mi.T) / norm_factor
@@ -157,7 +157,7 @@ def balance(
     # Classwise normalization via entropy
     classwise_bin_cnts = get_counts(tgt_bin)
     ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + …
+    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + EPSILON
     classwise = classwise_mi / norm_factor

     # Grabbing factor names for plotting function
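
To illustrate what the EPSILON guard buys in the normalization above, here is a standalone sketch of that step; the mi matrix and entropies are made-up values, not DataEval output:

    import numpy as np

    EPSILON = 1e-10  # same constant balance() now imports from dataeval.config

    # Hypothetical pairwise mutual information and per-factor entropies
    mi = np.array([[0.90, 0.20], [0.25, 0.90]])
    ent_factor = np.array([1.0, 0.0])  # second factor is constant (zero entropy)

    # Without EPSILON the (1, 1) cell would divide by zero; with it, the
    # symmetrized, entropy-normalized MI stays finite, mirroring balance()
    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
    nmi = 0.5 * (mi + mi.T) / norm_factor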
dataeval/metrics/bias/_completeness.py
ADDED
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import itertools
+
+__all__ = []
+
+
+import numpy as np
+
+from dataeval.config import EPSILON
+from dataeval.outputs import CompletenessOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import ensure_embeddings
+
+
+def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+    """
+    Calculate the fraction of boxes in a grid defined by quantiles that
+    contain at least one data point.
+    Also returns the center coordinates of each empty box.
+
+    Parameters
+    ----------
+    embeddings : ArrayLike
+        Embedded dataset (or other low-dimensional data) (nxp)
+    quantiles : int
+        number of quantile values to use for partitioning each dimension
+        e.g., 1 would create a grid of 2^p boxes, 2, 3^p etc..
+
+    Returns
+    -------
+    CompletenessOutput
+        - fraction_filled: float - Fraction of boxes that contain at least one
+          data point
+        - empty_box_centers: List[np.ndarray] - List of coordinates for centers of empty
+          boxes
+
+    Raises
+    ------
+    ValueError
+        If embeddings are too high-dimensional (>10)
+    ValueError
+        If there are too many quantiles (>2)
+    ValueError
+        If embedding is invalid shape
+
+    Example
+    -------
+    >>> embs = np.array([[1, 0], [0, 1], [1, 1]])
+    >>> quantiles = 1
+    >>> result = completeness(embs, quantiles)
+    >>> result.fraction_filled
+    0.75
+
+    Reference
+    ---------
+    This implementation is based on https://arxiv.org/abs/2002.03147.
+
+    [1] Byun, Taejoon, and Sanjai Rayadurgam. “Manifold for Machine Learning Assurance.”
+    Proceedings of the ACM/IEEE 42nd International Conference on Software Engineering
+    """
+    # Ensure proper data format
+    embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=False)
+
+    # Get data dimensions
+    n, p = embeddings.shape
+    if quantiles > 2 or quantiles <= 0:
+        raise ValueError(
+            f"Number of quantiles ({quantiles}) is greater than 2 or is nonpositive. \
+            The metric scales exponentially in this value. Please use 1 or 2 quantiles."
+        )
+    if p > 10:
+        raise ValueError(
+            f"Dimension of embeddings ({p}) is greater than 10. \
+            The metric scales exponentially in this value. Please reduce the embedding dimension."
+        )
+    if n == 0 or p == 0:
+        raise ValueError("Your provided embeddings do not contain any data!")
+    # quantiles + 2 edges partition each embedding dimension (e.g. [0, 0.5, 1] for quantiles = 1)
+    quantile_vec = np.linspace(0, 1, quantiles + 2)
+
+    # Calculate the bin edges for each dimension based on quantiles
+    bin_edges = []
+    for dim in range(p):
+        # Calculate the quantile values for this feature
+        edges = np.array(np.quantile(embeddings[:, dim], quantile_vec))
+        # Make sure the last bin contains all the remaining points
+        edges[-1] += EPSILON
+        bin_edges.append(edges)
+    # Convert each data point into its corresponding grid cell indices
+    grid_indices = []
+    for dim in range(p):
+        # For each dimension, find which bin each data point belongs to
+        # Digitize is 1 indexed so we subtract 1
+        indices = np.digitize(embeddings[:, dim], bin_edges[dim]) - 1
+        grid_indices.append(indices)
+
+    # Make the rows the data point and the column the grid index
+    grid_coords = np.array(grid_indices).T
+
+    # Use set to find unique tuple of grid coordinates
+    occupied_cells = set(map(tuple, grid_coords))
+
+    # For the fraction
+    num_occupied_cells = len(occupied_cells)
+
+    # Calculate total possible cells in the grid
+    num_bins_per_dim = [len(edges) - 1 for edges in bin_edges]
+    total_possible_cells = np.prod(num_bins_per_dim)
+
+    # Generate all possible grid cells
+    all_cells = set(itertools.product(*[range(bins) for bins in num_bins_per_dim]))
+
+    # Find the empty cells (cells with no data points)
+    empty_cells = all_cells - occupied_cells
+
+    # Calculate center points of empty boxes
+    empty_box_centers = []
+    for cell in empty_cells:
+        center_coords = []
+        for dim, idx in enumerate(cell):
+            # Calculate center of the bin as midpoint between edges
+            center = (bin_edges[dim][idx] + bin_edges[dim][idx + 1]) / 2
+            center_coords.append(center)
+        empty_box_centers.append(np.array(center_coords))
+
+    # Calculate the fraction
+    fraction = float(num_occupied_cells / total_possible_cells)
+    empty_box_centers = np.array(empty_box_centers)
+    return CompletenessOutput(fraction, empty_box_centers)
dataeval/metrics/estimators/_ber.py
CHANGED
@@ -19,6 +19,7 @@ from numpy.typing import NDArray
 from scipy.sparse import coo_matrix
 from scipy.stats import mode

+from dataeval.config import EPSILON
 from dataeval.outputs import BEROutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
@@ -82,7 +83,7 @@ def ber_knn(images: NDArray[np.float64], labels: NDArray[np.int_], k: int) -> tu

 def knn_lowerbound(value: float, classes: int, k: int) -> float:
     """Several cases for computing the BER lower bound"""
-    if value <= …
+    if value <= EPSILON:
         return 0.0

     if classes == 2 and k != 1: