dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +23 -14
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
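
The dominant pattern in this release shows in the rename map above: implementation modules move behind underscore-prefixed private names (output.py → _output.py, metrics/stats/base.py → _base.py, and so on), while new public surfaces land in dataeval/typing.py and dataeval/config.py. A minimal migration sketch using only import paths confirmed by the hunks below; anything not shown in a hunk should be verified against the released wheel rather than taken from this diff:

# Paths confirmed by the 0.81.0 hunks in this diff:
from dataeval._output import Output, set_metadata        # was dataeval/output.py
from dataeval.typing import ArrayLike                    # replaces numpy.typing.ArrayLike in signatures
from dataeval.config import get_device                   # new config module (+77 lines)
from dataeval.metrics.stats._hashstats import hashstats  # was metrics/stats/hashstats.py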
dataeval/detectors/linters/duplicates.py
CHANGED
@@ -3,13 +3,14 @@ from __future__ import annotations
 __all__ = []
 
 from dataclasses import dataclass
-from typing import Generic, Iterable, Sequence, TypeVar, overload
+from typing import Any, Generic, Iterable, Sequence, TypeVar, overload
 
-from
+from torch.utils.data import Dataset
 
-from dataeval.
-from dataeval.metrics.stats.
-from dataeval.
+from dataeval._output import Output, set_metadata
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._hashstats import HashStatsOutput, hashstats
+from dataeval.typing import ArrayLike
 
 DuplicateGroup = list[int]
 DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -19,7 +20,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
 @dataclass(frozen=True)
 class DuplicatesOutput(Generic[TIndexCollection], Output):
     """
-    Output class for :class
+    Output class for :class:`.Duplicates` lint detector.
 
     Attributes
     ----------
@@ -133,8 +134,15 @@ class Duplicates:
 
         return DuplicatesOutput(**duplicates)
 
+    @overload
+    def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]: ...
+    @overload
+    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> DuplicatesOutput[DuplicateGroup]: ...
+
     @set_metadata(state=["only_exact"])
-    def evaluate(
+    def evaluate(
+        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
+    ) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches
 
@@ -158,6 +166,7 @@ class Duplicates:
         >>> all_dupes.evaluate(duplicate_images)
         DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
         """  # noqa: E501
-
+        images = (d[0] for d in data) if isinstance(data, Dataset) else data
+        self.stats = hashstats(images)
         duplicates = self._get_duplicates(self.stats.dict())
         return DuplicatesOutput(**duplicates)
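
With the overloads above, evaluate accepts either a plain iterable of arrays or a torch.utils.data.Dataset of (image, target, metadata) tuples, pulling images out itself via d[0]. A minimal sketch of the iterable form; the data is illustrative and the public import path is assumed carried over from 0.76.1:

import numpy as np

from dataeval.detectors.linters import Duplicates  # assumed still re-exported in 0.81.0

# Illustrative data: eight random images, with index 5 an exact copy of index 2.
rng = np.random.default_rng(0)
images = [rng.random((3, 16, 16)) for _ in range(8)]
images[5] = images[2].copy()

result = Duplicates().evaluate(images)  # Iterable[ArrayLike] overload
print(result.exact)                     # expect indices 2 and 5 grouped together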
dataeval/detectors/linters/outliers.py
CHANGED
@@ -4,19 +4,20 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
 
 import numpy as np
-from numpy.typing import
-
-
-from dataeval.
-from dataeval.metrics.stats.
-from dataeval.metrics.stats.
-from dataeval.metrics.stats.
-from dataeval.metrics.stats.
-from dataeval.metrics.stats.
-from dataeval.
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from dataeval._output import Output, set_metadata
+from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._datasetstats import DatasetStatsOutput, datasetstats
+from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
+from dataeval.metrics.stats._labelstats import LabelStatsOutput
+from dataeval.metrics.stats._pixelstats import PixelStatsOutput
+from dataeval.metrics.stats._visualstats import VisualStatsOutput
+from dataeval.typing import ArrayLike
 
 with contextlib.suppress(ImportError):
     import pandas as pd
@@ -84,7 +85,7 @@ def _create_pandas_dataframe(class_wise):
 @dataclass(frozen=True)
 class OutliersOutput(Generic[TIndexIssueMap], Output):
     """
-    Output class for :class
+    Output class for :class:`.Outliers` lint detector.
 
     Attributes
     ----------
@@ -322,8 +323,15 @@ class Outliers:
 
         return OutliersOutput(output_list)
 
+    @overload
+    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]: ...
+    @overload
+    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> OutliersOutput[IndexIssueMap]: ...
+
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(
+    def evaluate(
+        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
+    ) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
@@ -349,6 +357,7 @@ class Outliers:
         >>> results.issues[10]
         {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
         """
-
+        images = (d[0] for d in data) if isinstance(data, Dataset) else data
+        self.stats = datasetstats(images=images)
        outliers = self._get_outliers(self.stats.dict())
        return OutliersOutput(outliers)
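
Outliers gains the identical pair of overloads, so the Dataset branch works with any map-style dataset yielding (image, target, metadata) triples. A sketch under that assumption; TupleDataset is hypothetical scaffolding, not part of dataeval:

import numpy as np
from torch.utils.data import Dataset

from dataeval.detectors.linters import Outliers  # assumed still re-exported in 0.81.0


class TupleDataset(Dataset):
    # Hypothetical wrapper yielding the (image, target, metadata) triples the
    # new overload expects; evaluate only consumes d[0].
    def __init__(self, images):
        self.images = images

    def __getitem__(self, idx):
        return self.images[idx], 0, {}  # dummy target and metadata


rng = np.random.default_rng(1)
images = [rng.random((3, 16, 16)).astype(np.float32) for _ in range(16)]
images[3] = np.zeros((3, 16, 16), dtype=np.float32)  # an obvious visual outlier

results = Outliers().evaluate(TupleDataset(images))
print(results.issues)  # index 3 should surface with zeros/contrast issues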
dataeval/detectors/ood/ae.py
CHANGED
@@ -16,12 +16,12 @@ from typing import Callable
 
 import numpy as np
 import torch
-from numpy.typing import
+from numpy.typing import NDArray
 
 from dataeval.detectors.ood.base import OODBase
 from dataeval.detectors.ood.output import OODScoreOutput
-from dataeval.
-from dataeval.utils.torch.
+from dataeval.typing import ArrayLike
+from dataeval.utils.torch._internal import predict_batch
 
 
 class OOD_AE(OODBase):
@@ -30,8 +30,31 @@ class OOD_AE(OODBase):
 
     Parameters
     ----------
-    model :
-    An
+    model : torch.nn.Module
+        An autoencoder model to use for encoding and reconstruction of images
+        for detection of out-of-distribution samples.
+    device : str or torch.Device or None, default None
+        The device to use for the detector. None will default to the global
+        configuration selection if set, otherwise "cuda" then "cpu" by availability.
+
+    Example
+    -------
+    Perform out-of-distribution detection on test data.
+
+    >>> from dataeval.utils.torch.models import AE
+
+    >>> input_shape = train_images[0].shape
+    >>> ood = OOD_AE(AE(input_shape))
+
+    Train the autoencoder using the training data.
+
+    >>> ood.fit(train_images, threshold_perc=99, epochs=20)
+
+    Test for out-of-distribution samples on the test data.
+
+    >>> output = ood.predict(test_images)
+    >>> output.is_ood
+    array([ True,  True, False,  True,  True,  True,  True,  True])
     """
 
     def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
@@ -55,9 +78,7 @@ class OOD_AE(OODBase):
 
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
-    def _score(self, X:
-        self._validate(X := as_numpy(X))
-
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
         # reconstruct instances
         X_recon = predict_batch(X, self.model, batch_size=batch_size)
 
dataeval/detectors/ood/base.py
CHANGED
@@ -13,12 +13,13 @@ __all__ = []
 from typing import Callable, cast
 
 import torch
-from numpy.typing import ArrayLike
 
+from dataeval.config import get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
-from dataeval.
-from dataeval.utils.
-from dataeval.utils.torch.
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import to_numpy
+from dataeval.utils.torch._gmm import GaussianMixtureModelParams, gmm_params
+from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
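
get_device is imported from the new dataeval/config.py (+77 in the file list), centralizing device selection that detectors previously handled individually. Per the OOD_AE docstring earlier in this diff, the precedence is: explicit argument, then the global configuration if set, then "cuda"/"cpu" by availability. An illustrative restatement of that order, paraphrased from the docstring rather than read from config.py itself:

import torch

def resolve_device(requested=None, global_default=None):
    # Illustrative precedence only -- not dataeval's actual implementation:
    # explicit argument > global config > "cuda" if available > "cpu".
    if requested is not None:
        return torch.device(requested)
    if global_default is not None:
        return torch.device(global_default)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")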
dataeval/detectors/ood/metadata_ks_compare.py
CHANGED
@@ -11,7 +11,7 @@ from numpy.typing import NDArray
 from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd
 
-from dataeval.
+from dataeval._output import MappingOutput, set_metadata
 
 
 class MetadataKSResult(NamedTuple):
dataeval/detectors/ood/mixin.py
CHANGED
@@ -8,10 +8,11 @@ from abc import ABC, abstractmethod
 from typing import Callable, Generic, Literal, TypeVar
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 
-from dataeval.
-from dataeval.
+from dataeval._output import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy, to_numpy
 
 TGMMParams = TypeVar("TGMMParams")
 
@@ -73,6 +74,9 @@ class OODBaseMixin(Generic[TModel], ABC):
     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
             raise TypeError("Dataset should of type: `NDArray`.")
+        if np.min(X) < 0 or np.max(X) > 1:
+            raise ValueError("Embeddings must be on the unit interval [0-1].")
+
         return X.shape[1:], X.dtype.type
 
     def _validate(self, X: NDArray) -> None:
@@ -90,7 +94,7 @@ class OODBaseMixin(Generic[TModel], ABC):
         self._validate(X)
 
     @abstractmethod
-    def _score(self, X:
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput: ...
 
     @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
@@ -105,11 +109,17 @@ class OODBaseMixin(Generic[TModel], ABC):
             Number of instances to process in each batch.
             Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         OODScoreOutput
             An object containing the instance-level and feature-level OOD scores.
         """
+        self._validate(X := as_numpy(X).astype(np.float32))
         return self._score(X, batch_size)
 
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
@@ -134,12 +144,17 @@ class OODBaseMixin(Generic[TModel], ABC):
         ood_type : "feature" | "instance", default "instance"
             Predict out-of-distribution at the 'feature' or 'instance' level.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         Dictionary containing the outlier predictions for the selected level,
         and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
         """
-        self._validate_state(X := to_numpy(X))
+        self._validate_state(X := to_numpy(X).astype(np.float32))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
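
The new range check in _get_data_info means score and predict now reject inputs outside [0, 1] with the ValueError documented above, so raw 0-255 imagery must be scaled first. A caller-side sketch; to_unit_interval is a hypothetical helper, not a dataeval function:

import numpy as np

def to_unit_interval(x: np.ndarray) -> np.ndarray:
    # Min-max scale into [0, 1] so the new validation in _get_data_info
    # ("Embeddings must be on the unit interval [0-1]") accepts the input.
    x = x.astype(np.float32)
    lo, hi = x.min(), x.max()
    return np.zeros_like(x) if hi == lo else (x - lo) / (hi - lo)

images = np.random.randint(0, 256, size=(8, 3, 32, 32), dtype=np.uint8)
scaled = to_unit_interval(images)
assert scaled.min() >= 0.0 and scaled.max() <= 1.0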
dataeval/detectors/ood/vae.py
ADDED
@@ -0,0 +1,73 @@
+"""
+Adapted for Pytorch from
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Callable
+
+import numpy as np
+import torch
+
+from dataeval.detectors.ood.base import OODBase
+from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch._internal import predict_batch
+
+
+class OOD_VAE(OODBase):
+    """
+    Autoencoder based out-of-distribution detector.
+
+    Parameters
+    ----------
+    model : Autoencoder
+        An Autoencoder model.
+    """
+
+    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+        super().__init__(model, device)
+
+    def fit(
+        self,
+        x_ref: ArrayLike,
+        threshold_perc: float,
+        loss_fn: Callable[..., torch.nn.Module] | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        epochs: int = 20,
+        batch_size: int = 64,
+        verbose: bool = False,
+    ) -> None:
+        if loss_fn is None:
+            loss_fn = torch.nn.MSELoss()
+
+        if optimizer is None:
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
+
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+
+    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
+        self._validate(X := as_numpy(X))
+
+        # reconstruct instances
+        X_recon = predict_batch(X, self.model, batch_size=batch_size)[0]  # don't need mu or logvar from model
+
+        # compute feature and instance level scores
+        fscore = np.power(X.reshape((len(X), -1)) - X_recon, 2)
+        # fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
+        # n_score_features = int(np.ceil(fscore_flat.shape[1]))
+        # sorted_fscore = np.sort(fscore_flat, axis=1)
+        # sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
+        # iscore = np.mean(sorted_fscore_perc, axis=1)
+        iscore = np.sum(fscore, axis=1)
+
+        return OODScoreOutput(iscore, fscore)
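
Note the indexing in _score: predict_batch(...)[0] keeps only the first element of the model output, so OOD_VAE expects a forward pass that returns the reconstruction first (e.g. (recon, mu, logvar)) and flat, matching X.reshape((len(X), -1)). A sketch of a compatible model; TinyVAE is hypothetical and purely illustrative:

import torch

from dataeval.detectors.ood.vae import OOD_VAE  # module added in 0.81.0


class TinyVAE(torch.nn.Module):
    # Hypothetical stand-in returning (reconstruction, mu, logvar) so that
    # predict_batch(...)[0] in OOD_VAE._score picks up the reconstruction.
    def __init__(self, dim: int, latent: int = 8):
        super().__init__()
        self.enc_mu = torch.nn.Linear(dim, latent)
        self.enc_logvar = torch.nn.Linear(dim, latent)
        self.dec = torch.nn.Linear(latent, dim)

    def forward(self, x):
        flat = x.flatten(1)
        mu, logvar = self.enc_mu(flat), self.enc_logvar(flat)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization
        return self.dec(z), mu, logvar


ood = OOD_VAE(TinyVAE(dim=3 * 16 * 16))
# ood.fit(train_images, threshold_perc=99)  # train_images scaled to [0, 1]
# output = ood.predict(test_images)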
dataeval/metadata/_ood.py
ADDED
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.detectors.ood import OODOutput
+from dataeval.utils.data import Metadata
+
+
+def _validate_keys(keys1: list[str], keys2: list[str]) -> None:
+    """
+    Raises error when two lists are not equivalent including ordering
+
+    Parameters
+    ----------
+    keys1 : list of strings
+        List of strings to compare
+    keys2 : list of strings
+        List of strings to compare
+
+    Raises
+    ------
+    ValueError
+        If lists do not have the same values, value counts, or ordering
+    """
+
+    if keys1 != keys2:
+        raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
+
+
+def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+    """
+    Raises error when the number of factors and number of rows do not match
+
+    Parameters
+    ----------
+    factors : list of strings
+        List of factor names of size N
+    data : NDArray
+        Array of values with shape (M, N)
+
+    Raises
+    ------
+    ValueError
+        If the length of factors does not equal the length of the transposed data
+    """
+    if len(factors) != len(data.T):
+        raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
+
+
+def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
+    """
+    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
+    match exactly and data has the same number of columns (factors).
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        The set of factor names used as reference to determine the correct factor names and length of data
+    metadata_2 : Metadata
+        The compared set of factor names and data that must match metadata_1
+
+    Returns
+    -------
+    list[str]
+        The combined discrete and continuous factor names in that order.
+    list[NDArray]
+        Combined discrete and continuous data of metadata_1
+    list[NDArray]
+        Combined discrete and continuous data of metadata_2
+
+    Raises
+    ------
+    ValueError
+        If keys do not match in metadata_1 and metadata_2
+    ValueError
+        If the length of keys do not match the length of the data
+    """
+    factor_names: list[str] = []
+    m1_data: list[NDArray] = []
+    m2_data: list[NDArray] = []
+
+    # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
+    if metadata_1.total_num_factors != metadata_2.total_num_factors:
+        raise ValueError(
+            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
+            f"and metadata_2 ({metadata_2.total_num_factors})"
+        )
+
+    # Validate and attach discrete data
+    if metadata_1.discrete_factor_names:
+        _validate_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
+        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
+
+        factor_names.extend(metadata_1.discrete_factor_names)
+        m1_data.append(metadata_1.discrete_data)
+        m2_data.append(metadata_2.discrete_data)
+
+    # Validate and attach continuous data
+    if metadata_1.continuous_factor_names:
+        _validate_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
+        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
+
+        factor_names.extend(metadata_1.continuous_factor_names)
+        m1_data.append(metadata_1.continuous_data)
+        m2_data.append(metadata_2.continuous_data)
+
+    # Turns list of discrete and continuous into one array
+    return factor_names, m1_data, m2_data
+
+
+def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
+    """
+    Calculates deviations of the test data from the median of the reference data
+
+    Parameters
+    ----------
+    reference : NDArray
+        Reference values of shape (samples, factors)
+    test : NDArray
+        Incoming values where each sample's factors will be compared to the median of
+        the reference set corresponding factors
+
+    Returns
+    -------
+    NDArray
+        Scaled positive and negative deviations of the test data from the reference.
+
+    Note
+    ----
+    All return values are in the range [0, pos_inf]
+    """
+
+    # Take median over samples (rows)
+    ref_median = np.median(reference, axis=0)  # (F, )
+
+    # Shift reference and test distributions by reference
+    ref_dev = reference - ref_median  # (S, F) - F
+    test_dev = test - ref_median  # (S_t, F) - F
+
+    # Separate positive and negative distributions
+    # Fills with nans to keep shape in both 1-D and N-D matrices
+    pdev = np.where(ref_dev > 0, ref_dev, np.nan)  # (S, F)
+    ndev = np.where(ref_dev < 0, ref_dev, np.nan)  # (S, F)
+
+    # Calculate middle of positive and negative distributions per feature
+    pscale = np.nanmedian(pdev, axis=0)  # (F, )
+    nscale = np.abs(np.nanmedian(ndev, axis=0))  # (F, )
+
+    # Replace 0's for division. Negatives should not happen
+    pscale = np.where(pscale > 0, pscale, 1.0)  # (F, )
+    nscale = np.where(nscale > 0, nscale, 1.0)  # (F, )
+
+    # Scales positive values by positive scale and negative values by negative
+    return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))  # (S_t, F)
+
+
+def most_deviated_factors(
+    metadata_1: Metadata,
+    metadata_2: Metadata,
+    ood: OODOutput,
+) -> list[tuple[str, float]]:
+    """
+    Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        A reference set of Metadata containing factor names and samples
+        with discrete and/or continuous values per factor
+    metadata_2 : Metadata
+        The set of Metadata that is tested against the reference metadata.
+        This set must have the same number of features but does not require the same number of samples.
+    ood : OODOutput
+        A class output by the DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    list[tuple[str, float]]
+        An array of the factor name and deviation of the highest metadata deviation for each OOD example in metadata_2.
+
+    Notes
+    -----
+    1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
+       and have equivalent factor names and lengths
+    2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
+       directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    # No metadata correlated with out of distribution data
+    if not any(ood_mask):
+        return []
+
+    # Combines reference and test factor names and data if exists and match exactly
+    # shape -> (samples, factors)
+    factor_names, md_1, md_2 = _combine_metadata(
+        metadata_1=metadata_1,
+        metadata_2=metadata_2,
+    )
+
+    metadata_ref = np.hstack(md_1) if md_1 else np.array([])
+    metadata_tst = np.hstack(md_2) if md_2 else np.array([])
+
+    if len(metadata_ref) < 3:
+        warnings.warn(
+            f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
+            UserWarning,
+        )
+        return []
+
+    if len(metadata_tst) != len(ood_mask):
+        raise ValueError(
+            f"ood and test metadata must have the same length, "
+            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+        )
+
+    # Calculates deviations of all samples in m2_data
+    # from the median values of the corresponding index in m1_data
+    # Guaranteed for inputs to not be empty
+    deviations = _calc_median_deviations(metadata_ref, metadata_tst)
+
+    # Get most impactful factor deviation of each sample for ood samples only
+    deviation = np.max(deviations, axis=1)[ood_mask]
+
+    # Get indices of most impactful factors for ood samples only
+    max_factors = np.argmax(deviations, axis=1)[ood_mask]
+
+    # Get names of most impactful factors TODO: Find better way than np.dtype(<U4)
+    most_ood_factors = np.array(factor_names)[max_factors].tolist()
+
+    # List of tuples matching the factor name with its deviation
+    return [(factor, dev.item()) for factor, dev in zip(most_ood_factors, deviation)]
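
The core of most_deviated_factors is _calc_median_deviations: shift each factor by the reference median, then scale positive and negative deviations by the median spread on each side. A standalone numpy walkthrough of that computation (a restatement of the function above, not an import of it):

import numpy as np

reference = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]])
test = np.array([[3.0, 30.0], [9.0, 31.0]])  # second sample is far off in factor 0

ref_median = np.median(reference, axis=0)  # [3., 30.]
ref_dev = reference - ref_median
test_dev = test - ref_median

# Median of the positive and negative reference deviations, per factor
pscale = np.nanmedian(np.where(ref_dev > 0, ref_dev, np.nan), axis=0)          # [1.5, 15.]
nscale = np.abs(np.nanmedian(np.where(ref_dev < 0, ref_dev, np.nan), axis=0))  # [1.5, 15.]
pscale = np.where(pscale > 0, pscale, 1.0)
nscale = np.where(nscale > 0, nscale, 1.0)

deviations = np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))
print(deviations)  # sample 0: [0., 0.]; sample 1: [4., ~0.067] -> factor 0 most deviated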
dataeval/metrics/bias/__init__.py
CHANGED
@@ -7,6 +7,7 @@ __all__ = [
     "BalanceOutput",
     "CoverageOutput",
     "DiversityOutput",
+    "LabelParityOutput",
     "ParityOutput",
     "balance",
     "coverage",
@@ -15,7 +16,7 @@ __all__ = [
     "parity",
 ]
 
-from dataeval.metrics.bias.
-from dataeval.metrics.bias.
-from dataeval.metrics.bias.
-from dataeval.metrics.bias.
+from dataeval.metrics.bias._balance import BalanceOutput, balance
+from dataeval.metrics.bias._coverage import CoverageOutput, coverage
+from dataeval.metrics.bias._diversity import DiversityOutput, diversity
+from dataeval.metrics.bias._parity import LabelParityOutput, ParityOutput, label_parity, parity