dataeval 0.74.2__py3-none-any.whl → 0.76.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +27 -23
- dataeval/detectors/__init__.py +2 -2
- dataeval/detectors/drift/__init__.py +14 -12
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/cvm.py +1 -1
- dataeval/detectors/drift/ks.py +3 -2
- dataeval/detectors/drift/mmd.py +9 -7
- dataeval/detectors/drift/torch.py +12 -12
- dataeval/detectors/drift/uncertainty.py +5 -4
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +4 -4
- dataeval/detectors/linters/clusterer.py +5 -9
- dataeval/detectors/linters/duplicates.py +10 -14
- dataeval/detectors/linters/outliers.py +100 -5
- dataeval/detectors/ood/__init__.py +4 -11
- dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
- dataeval/detectors/ood/base.py +47 -160
- dataeval/detectors/ood/metadata_ks_compare.py +34 -42
- dataeval/detectors/ood/metadata_least_likely.py +3 -3
- dataeval/detectors/ood/metadata_ood_mi.py +6 -5
- dataeval/detectors/ood/mixin.py +146 -0
- dataeval/detectors/ood/output.py +63 -0
- dataeval/interop.py +7 -6
- dataeval/{logging.py → log.py} +2 -0
- dataeval/metrics/__init__.py +3 -3
- dataeval/metrics/bias/__init__.py +10 -13
- dataeval/metrics/bias/balance.py +13 -11
- dataeval/metrics/bias/coverage.py +53 -5
- dataeval/metrics/bias/diversity.py +56 -24
- dataeval/metrics/bias/parity.py +20 -17
- dataeval/metrics/estimators/__init__.py +2 -2
- dataeval/metrics/estimators/ber.py +7 -4
- dataeval/metrics/estimators/divergence.py +4 -4
- dataeval/metrics/estimators/uap.py +4 -4
- dataeval/metrics/stats/__init__.py +19 -19
- dataeval/metrics/stats/base.py +28 -12
- dataeval/metrics/stats/boxratiostats.py +13 -14
- dataeval/metrics/stats/datasetstats.py +49 -20
- dataeval/metrics/stats/dimensionstats.py +8 -8
- dataeval/metrics/stats/hashstats.py +14 -10
- dataeval/metrics/stats/labelstats.py +94 -11
- dataeval/metrics/stats/pixelstats.py +11 -14
- dataeval/metrics/stats/visualstats.py +10 -13
- dataeval/output.py +23 -14
- dataeval/utils/__init__.py +5 -14
- dataeval/utils/dataset/__init__.py +7 -0
- dataeval/utils/{torch → dataset}/datasets.py +2 -0
- dataeval/utils/dataset/read.py +63 -0
- dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
- dataeval/utils/image.py +2 -2
- dataeval/utils/metadata.py +317 -14
- dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +91 -71
- dataeval/utils/torch/__init__.py +2 -17
- dataeval/utils/torch/gmm.py +29 -6
- dataeval/utils/torch/{utils.py → internal.py} +82 -58
- dataeval/utils/torch/models.py +10 -8
- dataeval/utils/torch/trainer.py +6 -85
- dataeval/workflows/__init__.py +2 -5
- dataeval/workflows/sufficiency.py +18 -8
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
- dataeval-0.76.0.dist-info/METADATA +137 -0
- dataeval-0.76.0.dist-info/RECORD +67 -0
- dataeval/detectors/ood/base_torch.py +0 -109
- dataeval/metrics/bias/metadata_preprocessing.py +0 -285
- dataeval/utils/gmm.py +0 -26
- dataeval-0.74.2.dist-info/METADATA +0 -120
- dataeval-0.74.2.dist-info/RECORD +0 -66
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -1,36 +1,40 @@
|
|
1
|
-
|
1
|
+
"""
|
2
|
+
DataEval provides a simple interface to characterize image data and its impact on model performance
|
3
|
+
across classification and object-detection tasks. It also provides capabilities to select and curate
|
4
|
+
datasets to test and train performant, robust, unbiased and reliable AI models and monitor for data
|
5
|
+
shifts that impact performance of deployed models.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
__all__ = ["detectors", "log", "metrics", "utils", "workflows"]
|
11
|
+
__version__ = "0.76.0"
|
2
12
|
|
3
13
|
import logging
|
4
|
-
|
14
|
+
|
15
|
+
from dataeval import detectors, metrics, utils, workflows
|
5
16
|
|
6
17
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
7
18
|
|
8
19
|
|
9
|
-
def
|
20
|
+
def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> None:
|
10
21
|
"""
|
11
|
-
Helper for quickly adding a StreamHandler to the logger. Useful for
|
12
|
-
|
22
|
+
Helper for quickly adding a StreamHandler to the logger. Useful for debugging.
|
23
|
+
|
24
|
+
Parameters
|
25
|
+
----------
|
26
|
+
level : int, default logging.DEBUG(10)
|
27
|
+
Set the logging level for the logger.
|
28
|
+
handler : logging.Handler, optional
|
29
|
+
Sets the logging handler for the logger if provided, otherwise logger will be
|
30
|
+
provided with a StreamHandler.
|
13
31
|
"""
|
14
32
|
import logging
|
15
33
|
|
16
34
|
logger = logging.getLogger(__name__)
|
17
|
-
handler
|
18
|
-
|
35
|
+
if handler is None:
|
36
|
+
handler = logging.StreamHandler() if handler is None else handler
|
37
|
+
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
|
19
38
|
logger.addHandler(handler)
|
20
39
|
logger.setLevel(level)
|
21
|
-
logger.debug("Added
|
22
|
-
|
23
|
-
|
24
|
-
_IS_TORCH_AVAILABLE = find_spec("torch") is not None
|
25
|
-
_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
|
26
|
-
|
27
|
-
del find_spec
|
28
|
-
|
29
|
-
from dataeval import detectors, metrics # noqa: E402
|
30
|
-
|
31
|
-
__all__ = ["log_stderr", "detectors", "metrics"]
|
32
|
-
|
33
|
-
if _IS_TORCH_AVAILABLE:
|
34
|
-
from dataeval import utils, workflows
|
35
|
-
|
36
|
-
__all__ += ["utils", "workflows"]
|
40
|
+
logger.debug(f"Added logging handler {handler} to logger: {__name__}")
|
dataeval/detectors/__init__.py
CHANGED
@@ -2,19 +2,21 @@
|
|
2
2
|
:term:`Drift` detectors identify if the statistical properties of the data have changed.
|
3
3
|
"""
|
4
4
|
|
5
|
-
|
5
|
+
__all__ = [
|
6
|
+
"DriftCVM",
|
7
|
+
"DriftKS",
|
8
|
+
"DriftMMD",
|
9
|
+
"DriftMMDOutput",
|
10
|
+
"DriftOutput",
|
11
|
+
"DriftUncertainty",
|
12
|
+
"preprocess_drift",
|
13
|
+
"updates",
|
14
|
+
]
|
15
|
+
|
6
16
|
from dataeval.detectors.drift import updates
|
7
17
|
from dataeval.detectors.drift.base import DriftOutput
|
8
18
|
from dataeval.detectors.drift.cvm import DriftCVM
|
9
19
|
from dataeval.detectors.drift.ks import DriftKS
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
if _IS_TORCH_AVAILABLE:
|
14
|
-
from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
|
15
|
-
from dataeval.detectors.drift.torch import preprocess_drift
|
16
|
-
from dataeval.detectors.drift.uncertainty import DriftUncertainty
|
17
|
-
|
18
|
-
__all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
|
19
|
-
|
20
|
-
del _IS_TORCH_AVAILABLE
|
20
|
+
from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
|
21
|
+
from dataeval.detectors.drift.torch import preprocess_drift
|
22
|
+
from dataeval.detectors.drift.uncertainty import DriftUncertainty
|
dataeval/detectors/drift/base.py
CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
-
__all__ = [
|
11
|
+
__all__ = []
|
12
12
|
|
13
13
|
from abc import ABC, abstractmethod
|
14
14
|
from dataclasses import dataclass
|
@@ -45,7 +45,7 @@ class UpdateStrategy(ABC):
|
|
45
45
|
@dataclass(frozen=True)
|
46
46
|
class DriftBaseOutput(Output):
|
47
47
|
"""
|
48
|
-
Base output class for Drift
|
48
|
+
Base output class for Drift Detector classes
|
49
49
|
|
50
50
|
Attributes
|
51
51
|
----------
|
@@ -64,7 +64,7 @@ class DriftBaseOutput(Output):
|
|
64
64
|
@dataclass(frozen=True)
|
65
65
|
class DriftOutput(DriftBaseOutput):
|
66
66
|
"""
|
67
|
-
Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors
|
67
|
+
Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
|
68
68
|
|
69
69
|
Attributes
|
70
70
|
----------
|
dataeval/detectors/drift/cvm.py
CHANGED
dataeval/detectors/drift/ks.py
CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
-
__all__ = [
|
11
|
+
__all__ = []
|
12
12
|
|
13
13
|
from typing import Callable, Literal
|
14
14
|
|
@@ -22,7 +22,8 @@ from dataeval.interop import to_numpy
|
|
22
22
|
|
23
23
|
class DriftKS(BaseDriftUnivariate):
|
24
24
|
"""
|
25
|
-
:term:`Drift` detector employing the Kolmogorov-Smirnov (KS)
|
25
|
+
:term:`Drift` detector employing the :term:`Kolmogorov-Smirnov (KS) \
|
26
|
+
distribution<Kolmogorov-Smirnov (K-S) test>` test.
|
26
27
|
|
27
28
|
The KS test detects changes in the maximum distance between two data
|
28
29
|
distributions with Bonferroni or :term:`False Discovery Rate (FDR)` correction
|
dataeval/detectors/drift/mmd.py
CHANGED
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
-
__all__ = [
|
11
|
+
__all__ = []
|
12
12
|
|
13
13
|
from dataclasses import dataclass
|
14
14
|
from typing import Callable
|
@@ -17,15 +17,16 @@ import torch
|
|
17
17
|
from numpy.typing import ArrayLike
|
18
18
|
|
19
19
|
from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
|
20
|
-
from dataeval.detectors.drift.torch import
|
20
|
+
from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
|
21
21
|
from dataeval.interop import as_numpy
|
22
22
|
from dataeval.output import set_metadata
|
23
|
+
from dataeval.utils.torch.internal import get_device
|
23
24
|
|
24
25
|
|
25
26
|
@dataclass(frozen=True)
|
26
27
|
class DriftMMDOutput(DriftBaseOutput):
|
27
28
|
"""
|
28
|
-
Output class for :class:`DriftMMD` :term:`drift<Drift>` detector
|
29
|
+
Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
|
29
30
|
|
30
31
|
Attributes
|
31
32
|
----------
|
@@ -50,7 +51,8 @@ class DriftMMDOutput(DriftBaseOutput):
|
|
50
51
|
|
51
52
|
class DriftMMD(BaseDrift):
|
52
53
|
"""
|
53
|
-
:term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm
|
54
|
+
:term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm \
|
55
|
+
using a permutation test.
|
54
56
|
|
55
57
|
Parameters
|
56
58
|
----------
|
@@ -109,7 +111,7 @@ class DriftMMD(BaseDrift):
|
|
109
111
|
|
110
112
|
# initialize kernel
|
111
113
|
sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
|
112
|
-
self._kernel =
|
114
|
+
self._kernel = GaussianRBF(sigma_tensor).to(self.device)
|
113
115
|
|
114
116
|
# compute kernel matrix for the reference data
|
115
117
|
if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
|
@@ -150,9 +152,9 @@ class DriftMMD(BaseDrift):
|
|
150
152
|
n = x.shape[0]
|
151
153
|
kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
|
152
154
|
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
|
153
|
-
mmd2 =
|
155
|
+
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
|
154
156
|
mmd2_permuted = torch.Tensor(
|
155
|
-
[
|
157
|
+
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
|
156
158
|
)
|
157
159
|
mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
|
158
160
|
p_val = (mmd2 <= mmd2_permuted).float().mean()
|
@@ -17,10 +17,10 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
from numpy.typing import NDArray
|
19
19
|
|
20
|
-
from dataeval.utils.torch.
|
20
|
+
from dataeval.utils.torch.internal import get_device, predict_batch
|
21
21
|
|
22
22
|
|
23
|
-
def
|
23
|
+
def mmd2_from_kernel_matrix(
|
24
24
|
kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
|
25
25
|
) -> torch.Tensor:
|
26
26
|
"""
|
@@ -127,7 +127,7 @@ def _squared_pairwise_distance(
|
|
127
127
|
|
128
128
|
def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
|
129
129
|
"""
|
130
|
-
Bandwidth estimation using the median heuristic
|
130
|
+
Bandwidth estimation using the median heuristic `Gretton2012`
|
131
131
|
|
132
132
|
Parameters
|
133
133
|
----------
|
@@ -151,7 +151,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
|
|
151
151
|
return sigma
|
152
152
|
|
153
153
|
|
154
|
-
class
|
154
|
+
class GaussianRBF(nn.Module):
|
155
155
|
"""
|
156
156
|
Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
|
157
157
|
|
@@ -179,18 +179,18 @@ class _GaussianRBF(nn.Module):
|
|
179
179
|
) -> None:
|
180
180
|
super().__init__()
|
181
181
|
init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
|
182
|
-
self.config = {
|
182
|
+
self.config: dict[str, Any] = {
|
183
183
|
"sigma": sigma,
|
184
184
|
"trainable": trainable,
|
185
185
|
"init_sigma_fn": init_sigma_fn,
|
186
186
|
}
|
187
187
|
if sigma is None:
|
188
|
-
self.log_sigma = nn.Parameter(torch.empty(1), requires_grad=trainable)
|
189
|
-
self.init_required = True
|
188
|
+
self.log_sigma: nn.Parameter = nn.Parameter(torch.empty(1), requires_grad=trainable)
|
189
|
+
self.init_required: bool = True
|
190
190
|
else:
|
191
191
|
sigma = sigma.reshape(-1) # [Ns,]
|
192
|
-
self.log_sigma = nn.Parameter(sigma.log(), requires_grad=trainable)
|
193
|
-
self.init_required = False
|
192
|
+
self.log_sigma: nn.Parameter = nn.Parameter(sigma.log(), requires_grad=trainable)
|
193
|
+
self.init_required: bool = False
|
194
194
|
self.init_sigma_fn = init_sigma_fn
|
195
195
|
self.trainable = trainable
|
196
196
|
|
@@ -200,8 +200,8 @@ class _GaussianRBF(nn.Module):
|
|
200
200
|
|
201
201
|
def forward(
|
202
202
|
self,
|
203
|
-
x: np.ndarray | torch.Tensor,
|
204
|
-
y: np.ndarray | torch.Tensor,
|
203
|
+
x: np.ndarray[Any, Any] | torch.Tensor,
|
204
|
+
y: np.ndarray[Any, Any] | torch.Tensor,
|
205
205
|
infer_sigma: bool = False,
|
206
206
|
) -> torch.Tensor:
|
207
207
|
x, y = torch.as_tensor(x), torch.as_tensor(y)
|
@@ -213,7 +213,7 @@ class _GaussianRBF(nn.Module):
|
|
213
213
|
sigma = self.init_sigma_fn(x, y, dist)
|
214
214
|
with torch.no_grad():
|
215
215
|
self.log_sigma.copy_(sigma.log().clone())
|
216
|
-
self.init_required = False
|
216
|
+
self.init_required: bool = False
|
217
217
|
|
218
218
|
gamma = 1.0 / (2.0 * self.sigma**2) # [Ns,]
|
219
219
|
# TODO: do matrix multiplication after all?
|
@@ -8,7 +8,7 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
-
__all__ = [
|
11
|
+
__all__ = []
|
12
12
|
|
13
13
|
from functools import partial
|
14
14
|
from typing import Callable, Literal
|
@@ -20,7 +20,8 @@ from scipy.stats import entropy
|
|
20
20
|
|
21
21
|
from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
|
22
22
|
from dataeval.detectors.drift.ks import DriftKS
|
23
|
-
from dataeval.detectors.drift.torch import
|
23
|
+
from dataeval.detectors.drift.torch import preprocess_drift
|
24
|
+
from dataeval.utils.torch.internal import get_device
|
24
25
|
|
25
26
|
|
26
27
|
def classifier_uncertainty(
|
@@ -65,8 +66,8 @@ def classifier_uncertainty(
|
|
65
66
|
|
66
67
|
class DriftUncertainty:
|
67
68
|
"""
|
68
|
-
Test for a change in the number of instances falling into regions on which
|
69
|
-
|
69
|
+
Test for a change in the number of instances falling into regions on which \
|
70
|
+
the model is uncertain.
|
70
71
|
|
71
72
|
Performs a K-S test on prediction entropies.
|
72
73
|
|
@@ -2,10 +2,6 @@
|
|
2
2
|
Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
|
6
|
-
from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
|
7
|
-
from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
|
8
|
-
|
9
5
|
__all__ = [
|
10
6
|
"Clusterer",
|
11
7
|
"ClustererOutput",
|
@@ -14,3 +10,7 @@ __all__ = [
|
|
14
10
|
"Outliers",
|
15
11
|
"OutliersOutput",
|
16
12
|
]
|
13
|
+
|
14
|
+
from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
|
15
|
+
from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
|
16
|
+
from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
__all__ = [
|
3
|
+
__all__ = []
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from typing import Any, Iterable, NamedTuple, cast
|
@@ -18,7 +18,7 @@ from dataeval.utils.shared import flatten
|
|
18
18
|
@dataclass(frozen=True)
|
19
19
|
class ClustererOutput(Output):
|
20
20
|
"""
|
21
|
-
Output class for :class:`Clusterer` lint detector
|
21
|
+
Output class for :class:`Clusterer` lint detector.
|
22
22
|
|
23
23
|
Attributes
|
24
24
|
----------
|
@@ -131,7 +131,8 @@ class _ClusterMergeEntry:
|
|
131
131
|
|
132
132
|
class Clusterer:
|
133
133
|
"""
|
134
|
-
Uses hierarchical clustering to flag dataset properties of interest like
|
134
|
+
Uses hierarchical clustering to flag dataset properties of interest like outliers \
|
135
|
+
and :term:`duplicates<Duplicates>`.
|
135
136
|
|
136
137
|
Parameters
|
137
138
|
----------
|
@@ -147,12 +148,6 @@ class Clusterer:
|
|
147
148
|
----
|
148
149
|
The Clusterer works best when the length of the feature dimension, P, is less than 500.
|
149
150
|
If flattening a CxHxW image results in a dimension larger than 500, then it is recommended to reduce the dimensions.
|
150
|
-
|
151
|
-
Example
|
152
|
-
-------
|
153
|
-
Initialize the Clusterer class:
|
154
|
-
|
155
|
-
>>> cluster = Clusterer(dataset)
|
156
151
|
"""
|
157
152
|
|
158
153
|
def __init__(self, dataset: ArrayLike) -> None:
|
@@ -506,6 +501,7 @@ class Clusterer:
|
|
506
501
|
|
507
502
|
Example
|
508
503
|
-------
|
504
|
+
>>> cluster = Clusterer(clusterer_images)
|
509
505
|
>>> cluster.evaluate()
|
510
506
|
ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
|
511
507
|
""" # noqa: E501
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
__all__ = [
|
3
|
+
__all__ = []
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from typing import Generic, Iterable, Sequence, TypeVar, overload
|
@@ -19,7 +19,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
|
|
19
19
|
@dataclass(frozen=True)
|
20
20
|
class DuplicatesOutput(Generic[TIndexCollection], Output):
|
21
21
|
"""
|
22
|
-
Output class for :class:`Duplicates` lint detector
|
22
|
+
Output class for :class:`Duplicates` lint detector.
|
23
23
|
|
24
24
|
Attributes
|
25
25
|
----------
|
@@ -39,8 +39,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
|
|
39
39
|
|
40
40
|
class Duplicates:
|
41
41
|
"""
|
42
|
-
Finds the duplicate images in a dataset using xxhash for exact
|
43
|
-
and pchash for near duplicates
|
42
|
+
Finds the duplicate images in a dataset using xxhash for exact \
|
43
|
+
:term:`duplicates<Duplicates>` and pchash for near duplicates.
|
44
44
|
|
45
45
|
Attributes
|
46
46
|
----------
|
@@ -51,13 +51,6 @@ class Duplicates:
|
|
51
51
|
----------
|
52
52
|
only_exact : bool, default False
|
53
53
|
Only inspect the dataset for exact image matches
|
54
|
-
|
55
|
-
Example
|
56
|
-
-------
|
57
|
-
Initialize the Duplicates class:
|
58
|
-
|
59
|
-
>>> all_dupes = Duplicates()
|
60
|
-
>>> exact_dupes = Duplicates(only_exact=True)
|
61
54
|
"""
|
62
55
|
|
63
56
|
def __init__(self, only_exact: bool = False) -> None:
|
@@ -73,7 +66,8 @@ class Duplicates:
|
|
73
66
|
if not self.only_exact:
|
74
67
|
near_dict: dict[int, list] = {}
|
75
68
|
for i, value in enumerate(stats["pchash"]):
|
76
|
-
|
69
|
+
if value:
|
70
|
+
near_dict.setdefault(value, []).append(i)
|
77
71
|
near = [sorted(v) for v in near_dict.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
|
78
72
|
else:
|
79
73
|
near = []
|
@@ -98,7 +92,7 @@ class Duplicates:
|
|
98
92
|
|
99
93
|
Parameters
|
100
94
|
----------
|
101
|
-
|
95
|
+
hashes : HashStatsOutput | Sequence[HashStatsOutput]
|
102
96
|
The output(s) from a hashstats analysis
|
103
97
|
|
104
98
|
Returns
|
@@ -112,6 +106,7 @@ class Duplicates:
|
|
112
106
|
|
113
107
|
Example
|
114
108
|
-------
|
109
|
+
>>> exact_dupes = Duplicates(only_exact=True)
|
115
110
|
>>> exact_dupes.from_stats([hashes1, hashes2])
|
116
111
|
DuplicatesOutput(exact=[{0: [3, 20]}, {0: [16], 1: [12]}], near=[])
|
117
112
|
"""
|
@@ -159,7 +154,8 @@ class Duplicates:
|
|
159
154
|
|
160
155
|
Example
|
161
156
|
-------
|
162
|
-
>>> all_dupes
|
157
|
+
>>> all_dupes = Duplicates()
|
158
|
+
>>> all_dupes.evaluate(duplicate_images)
|
163
159
|
DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
|
164
160
|
""" # noqa: E501
|
165
161
|
self.stats = hashstats(data)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
__all__ = [
|
3
|
+
__all__ = []
|
4
4
|
|
5
|
+
# import contextlib
|
5
6
|
from dataclasses import dataclass
|
6
7
|
from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
|
7
8
|
|
@@ -12,19 +13,78 @@ from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_s
|
|
12
13
|
from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
|
13
14
|
from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
|
14
15
|
from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
|
16
|
+
from dataeval.metrics.stats.labelstats import LabelStatsOutput
|
15
17
|
from dataeval.metrics.stats.pixelstats import PixelStatsOutput
|
16
18
|
from dataeval.metrics.stats.visualstats import VisualStatsOutput
|
17
19
|
from dataeval.output import Output, set_metadata
|
18
20
|
|
21
|
+
# with contextlib.suppress(ImportError):
|
22
|
+
# import pandas as pd
|
23
|
+
|
24
|
+
|
19
25
|
IndexIssueMap = dict[int, dict[str, float]]
|
20
26
|
OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
|
21
27
|
TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
|
22
28
|
|
23
29
|
|
30
|
+
def _reorganize_by_class_and_metric(result, lstats):
|
31
|
+
"""Flip result from grouping by image to grouping by class and metric"""
|
32
|
+
metrics = {}
|
33
|
+
class_wise = {label: {} for label in lstats.image_indices_per_label}
|
34
|
+
|
35
|
+
# Group metrics and calculate class-wise counts
|
36
|
+
for img, group in result.items():
|
37
|
+
for extreme in group:
|
38
|
+
metrics.setdefault(extreme, []).append(img)
|
39
|
+
for label, images in lstats.image_indices_per_label.items():
|
40
|
+
if img in images:
|
41
|
+
class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
|
42
|
+
|
43
|
+
return metrics, class_wise
|
44
|
+
|
45
|
+
|
46
|
+
def _create_table(metrics, class_wise):
|
47
|
+
"""Create table for displaying the results"""
|
48
|
+
max_class_length = max(len(str(label)) for label in class_wise) + 2
|
49
|
+
max_total = max(len(metrics[group]) for group in metrics) + 2
|
50
|
+
|
51
|
+
table_header = " | ".join(
|
52
|
+
[f"{'Class':>{max_class_length}}"]
|
53
|
+
+ [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
|
54
|
+
+ [f"{'Total':<{max_total}}"]
|
55
|
+
)
|
56
|
+
table_rows = []
|
57
|
+
|
58
|
+
for class_cat, results in class_wise.items():
|
59
|
+
table_value = [f"{class_cat:>{max_class_length}}"]
|
60
|
+
total = 0
|
61
|
+
for group in sorted(metrics.keys()):
|
62
|
+
count = results.get(group, 0)
|
63
|
+
table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
|
64
|
+
total += count
|
65
|
+
table_value.append(f"{total:^{max_total}}")
|
66
|
+
table_rows.append(" | ".join(table_value))
|
67
|
+
|
68
|
+
table = [table_header] + table_rows
|
69
|
+
return table
|
70
|
+
|
71
|
+
|
72
|
+
# def _create_pandas_dataframe(class_wise):
|
73
|
+
# """Create data for pandas dataframe"""
|
74
|
+
# data = []
|
75
|
+
# for label, metrics_dict in class_wise.items():
|
76
|
+
# row = {"Class": label}
|
77
|
+
# total = sum(metrics_dict.values())
|
78
|
+
# row.update(metrics_dict) # Add metric counts
|
79
|
+
# row["Total"] = total
|
80
|
+
# data.append(row)
|
81
|
+
# return data
|
82
|
+
|
83
|
+
|
24
84
|
@dataclass(frozen=True)
|
25
85
|
class OutliersOutput(Generic[TIndexIssueMap], Output):
|
26
86
|
"""
|
27
|
-
Output class for :class:`Outliers` lint detector
|
87
|
+
Output class for :class:`Outliers` lint detector.
|
28
88
|
|
29
89
|
Attributes
|
30
90
|
----------
|
@@ -45,6 +105,39 @@ class OutliersOutput(Generic[TIndexIssueMap], Output):
|
|
45
105
|
else:
|
46
106
|
return sum(len(d) for d in self.issues)
|
47
107
|
|
108
|
+
def to_table(self, labelstats: LabelStatsOutput) -> str:
    """
    Render the issues as a text table of per-class counts for each metric.

    Parameters
    ----------
    labelstats : LabelStatsOutput
        Label statistics used to assign flagged images to classes.

    Returns
    -------
    str
        One table when ``issues`` is a single mapping; otherwise one table per
        entry of ``issues``, separated by a blank line.
    """
    # Treat the single-mapping case as a one-element sequence so both shapes of
    # ``issues`` share the same rendering path.
    issue_maps = [self.issues] if isinstance(self.issues, dict) else self.issues

    rendered = []
    for issue_map in issue_maps:
        metrics, classwise = _reorganize_by_class_and_metric(issue_map, labelstats)
        rendered.append("\n".join(_create_table(metrics, classwise)))

    return "\n\n".join(rendered)
|
122
|
+
|
123
|
+
# def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
|
124
|
+
# import pandas as pd
|
125
|
+
|
126
|
+
# if isinstance(self.issues, dict):
|
127
|
+
# _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
|
128
|
+
# data = _create_pandas_dataframe(classwise)
|
129
|
+
# df = pd.DataFrame(data)
|
130
|
+
# else:
|
131
|
+
# df_list = []
|
132
|
+
# for i, d in enumerate(self.issues):
|
133
|
+
# _, classwise = _reorganize_by_class_and_metric(d, labelstats)
|
134
|
+
# data = _create_pandas_dataframe(classwise)
|
135
|
+
# single_df = pd.DataFrame(data)
|
136
|
+
# single_df["Dataset"] = i
|
137
|
+
# df_list.append(single_df)
|
138
|
+
# df = pd.concat(df_list)
|
139
|
+
# return df
|
140
|
+
|
48
141
|
|
49
142
|
def _get_outlier_mask(
|
50
143
|
values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
|
@@ -71,7 +164,7 @@ def _get_outlier_mask(
|
|
71
164
|
|
72
165
|
class Outliers:
|
73
166
|
r"""
|
74
|
-
Calculates statistical
|
167
|
+
Calculates statistical outliers of a dataset using various statistical tests applied to each image.
|
75
168
|
|
76
169
|
Parameters
|
77
170
|
----------
|
@@ -164,7 +257,7 @@ class Outliers:
|
|
164
257
|
self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
|
165
258
|
) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
|
166
259
|
"""
|
167
|
-
Returns indices of Outliers with the issues identified for each
|
260
|
+
Returns indices of Outliers with the issues identified for each.
|
168
261
|
|
169
262
|
Parameters
|
170
263
|
----------
|
@@ -188,6 +281,7 @@ class Outliers:
|
|
188
281
|
-------
|
189
282
|
Evaluate the dataset:
|
190
283
|
|
284
|
+
>>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
|
191
285
|
>>> results = outliers.from_stats([stats1, stats2])
|
192
286
|
>>> len(results)
|
193
287
|
2
|
@@ -248,7 +342,8 @@ class Outliers:
|
|
248
342
|
-------
|
249
343
|
Evaluate the dataset:
|
250
344
|
|
251
|
-
>>>
|
345
|
+
>>> outliers = Outliers(outlier_method="zscore", outlier_threshold=3.5)
|
346
|
+
>>> results = outliers.evaluate(outlier_images)
|
252
347
|
>>> list(results.issues)
|
253
348
|
[10, 12]
|
254
349
|
>>> results.issues[10]
|
@@ -1,15 +1,8 @@
|
|
1
1
|
"""
|
2
|
-
Out-of-distribution (OOD)
|
2
|
+
Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
|
3
3
|
"""
|
4
4
|
|
5
|
-
|
6
|
-
from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
|
5
|
+
__all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
|
7
6
|
|
8
|
-
|
9
|
-
|
10
|
-
if _IS_TORCH_AVAILABLE:
|
11
|
-
from dataeval.detectors.ood.ae_torch import OOD_AE
|
12
|
-
|
13
|
-
__all__ += ["OOD_AE"]
|
14
|
-
|
15
|
-
del _IS_TORCH_AVAILABLE
|
7
|
+
from dataeval.detectors.ood.ae import OOD_AE
|
8
|
+
from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
|