dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +8 -64
- dataeval/detectors/drift/_mmd.py +12 -38
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +6 -5
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -2
- dataeval/detectors/linters/duplicates.py +14 -46
- dataeval/detectors/linters/outliers.py +25 -159
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +6 -5
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +3 -4
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/metadata/__init__.py +2 -1
- dataeval/metadata/_distance.py +134 -0
- dataeval/metadata/_ood.py +30 -49
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/_balance.py +17 -149
- dataeval/metrics/bias/_coverage.py +4 -106
- dataeval/metrics/bias/_diversity.py +12 -107
- dataeval/metrics/bias/_parity.py +7 -71
- dataeval/metrics/estimators/__init__.py +5 -4
- dataeval/metrics/estimators/_ber.py +2 -20
- dataeval/metrics/estimators/_clusterer.py +1 -61
- dataeval/metrics/estimators/_divergence.py +2 -19
- dataeval/metrics/estimators/_uap.py +2 -16
- dataeval/metrics/stats/__init__.py +15 -12
- dataeval/metrics/stats/_base.py +41 -128
- dataeval/metrics/stats/_boxratiostats.py +13 -13
- dataeval/metrics/stats/_dimensionstats.py +17 -58
- dataeval/metrics/stats/_hashstats.py +19 -35
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +42 -121
- dataeval/metrics/stats/_pixelstats.py +19 -51
- dataeval/metrics/stats/_visualstats.py +19 -51
- dataeval/outputs/__init__.py +57 -0
- dataeval/outputs/_base.py +182 -0
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +186 -0
- dataeval/outputs/_metadata.py +54 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +393 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +187 -7
- dataeval/utils/_method.py +1 -5
- dataeval/utils/_plot.py +2 -2
- dataeval/utils/data/__init__.py +5 -1
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +12 -14
- dataeval/utils/data/_images.py +30 -27
- dataeval/utils/data/_metadata.py +28 -11
- dataeval/utils/data/_selection.py +25 -22
- dataeval/utils/data/_split.py +5 -29
- dataeval/utils/data/_targets.py +14 -2
- dataeval/utils/data/datasets/_base.py +5 -5
- dataeval/utils/data/datasets/_cifar10.py +1 -1
- dataeval/utils/data/datasets/_milco.py +1 -1
- dataeval/utils/data/datasets/_mnist.py +1 -1
- dataeval/utils/data/datasets/_ships.py +1 -1
- dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
- dataeval/utils/data/datasets/_voc.py +1 -1
- dataeval/utils/data/selections/_classfilter.py +4 -5
- dataeval/utils/data/selections/_indices.py +2 -2
- dataeval/utils/data/selections/_limit.py +2 -2
- dataeval/utils/data/selections/_reverse.py +2 -2
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +6 -342
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
- dataeval-0.82.1.dist-info/RECORD +105 -0
- dataeval/_output.py +0 -137
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/metrics/stats/_datasetstats.py +0 -198
- dataeval-0.81.0.dist-info/RECORD +0 -94
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
dataeval/config.py
CHANGED
@@ -4,36 +4,61 @@ Global configuration settings for DataEval.
|
|
4
4
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
-
__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
|
7
|
+
__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
|
8
8
|
|
9
|
+
import sys
|
10
|
+
from typing import Union
|
11
|
+
|
12
|
+
if sys.version_info >= (3, 10):
|
13
|
+
from typing import TypeAlias
|
14
|
+
else:
|
15
|
+
from typing_extensions import TypeAlias
|
16
|
+
|
17
|
+
import numpy as np
|
9
18
|
import torch
|
10
|
-
from torch import device
|
11
19
|
|
12
|
-
_device: device | None = None
|
20
|
+
_device: torch.device | None = None
|
13
21
|
_processes: int | None = None
|
22
|
+
_seed: int | None = None
|
23
|
+
|
24
|
+
DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
|
25
|
+
"""
|
26
|
+
Type alias for types that are acceptable for specifying a torch.device.
|
27
|
+
|
28
|
+
See Also
|
29
|
+
--------
|
30
|
+
`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
|
31
|
+
"""
|
32
|
+
|
14
33
|
|
34
|
+
def _todevice(device: DeviceLike) -> torch.device:
|
35
|
+
return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
|
15
36
|
|
16
|
-
|
37
|
+
|
38
|
+
def set_device(device: DeviceLike) -> None:
|
17
39
|
"""
|
18
40
|
Sets the default device to use when executing against a PyTorch backend.
|
19
41
|
|
20
42
|
Parameters
|
21
43
|
----------
|
22
|
-
device :
|
23
|
-
The default device to use. See
|
24
|
-
|
44
|
+
device : DeviceLike
|
45
|
+
The default device to use. See documentation for more information.
|
46
|
+
|
47
|
+
See Also
|
48
|
+
--------
|
49
|
+
`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
|
25
50
|
"""
|
26
51
|
global _device
|
27
|
-
_device =
|
52
|
+
_device = _todevice(device)
|
28
53
|
|
29
54
|
|
30
|
-
def get_device(override:
|
55
|
+
def get_device(override: DeviceLike | None = None) -> torch.device:
|
31
56
|
"""
|
32
57
|
Returns the PyTorch device to use.
|
33
58
|
|
34
59
|
Parameters
|
35
60
|
----------
|
36
|
-
override :
|
61
|
+
override : DeviceLike or None, default None
|
37
62
|
The user specified override if provided, otherwise returns the default device.
|
38
63
|
|
39
64
|
Returns
|
@@ -44,7 +69,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
|
|
44
69
|
global _device
|
45
70
|
return torch.get_default_device() if _device is None else _device
|
46
71
|
else:
|
47
|
-
return
|
72
|
+
return _todevice(override)
|
48
73
|
|
49
74
|
|
50
75
|
def set_max_processes(processes: int | None) -> None:
|
@@ -75,3 +100,35 @@ def get_max_processes() -> int | None:
|
|
75
100
|
"""
|
76
101
|
global _processes
|
77
102
|
return _processes
|
103
|
+
|
104
|
+
|
105
|
+
def set_seed(seed: int | None, all_generators: bool = False) -> None:
|
106
|
+
"""
|
107
|
+
Sets the seed for use by classes that allow for a random state or seed.
|
108
|
+
|
109
|
+
Parameters
|
110
|
+
----------
|
111
|
+
seed : int or None
|
112
|
+
The seed to use.
|
113
|
+
all_generators : bool, default False
|
114
|
+
Whether to set the seed for all generators, including NumPy and PyTorch.
|
115
|
+
"""
|
116
|
+
global _seed
|
117
|
+
_seed = seed
|
118
|
+
|
119
|
+
if all_generators:
|
120
|
+
np.random.seed(seed)
|
121
|
+
torch.manual_seed(seed)
|
122
|
+
|
123
|
+
|
124
|
+
def get_seed() -> int | None:
|
125
|
+
"""
|
126
|
+
Returns the seed for random state or seed.
|
127
|
+
|
128
|
+
Returns
|
129
|
+
-------
|
130
|
+
int or None
|
131
|
+
The seed to use.
|
132
|
+
"""
|
133
|
+
global _seed
|
134
|
+
return _seed
|
@@ -14,9 +14,9 @@ __all__ = [
|
|
14
14
|
]
|
15
15
|
|
16
16
|
from dataeval.detectors.drift import updates
|
17
|
-
from dataeval.detectors.drift._base import DriftOutput
|
18
17
|
from dataeval.detectors.drift._cvm import DriftCVM
|
19
18
|
from dataeval.detectors.drift._ks import DriftKS
|
20
|
-
from dataeval.detectors.drift._mmd import DriftMMD
|
19
|
+
from dataeval.detectors.drift._mmd import DriftMMD
|
21
20
|
from dataeval.detectors.drift._torch import preprocess_drift
|
22
21
|
from dataeval.detectors.drift._uncertainty import DriftUncertainty
|
22
|
+
from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
|
@@ -11,84 +11,28 @@ from __future__ import annotations
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
13
|
import math
|
14
|
-
from abc import
|
15
|
-
from dataclasses import dataclass
|
14
|
+
from abc import abstractmethod
|
16
15
|
from functools import wraps
|
17
|
-
from typing import Any, Callable, Literal, TypeVar
|
16
|
+
from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
|
18
17
|
|
19
18
|
import numpy as np
|
20
19
|
from numpy.typing import NDArray
|
21
20
|
|
22
|
-
from dataeval.
|
21
|
+
from dataeval.outputs import DriftOutput
|
22
|
+
from dataeval.outputs._base import set_metadata
|
23
23
|
from dataeval.typing import Array, ArrayLike
|
24
24
|
from dataeval.utils._array import as_numpy, to_numpy
|
25
25
|
|
26
26
|
R = TypeVar("R")
|
27
27
|
|
28
28
|
|
29
|
-
|
29
|
+
@runtime_checkable
|
30
|
+
class UpdateStrategy(Protocol):
|
30
31
|
"""
|
31
|
-
|
32
|
-
|
33
|
-
Parameters
|
34
|
-
----------
|
35
|
-
n : int
|
36
|
-
Update with last n instances seen by the detector.
|
37
|
-
"""
|
38
|
-
|
39
|
-
def __init__(self, n: int) -> None:
|
40
|
-
self.n = n
|
41
|
-
|
42
|
-
@abstractmethod
|
43
|
-
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
44
|
-
"""Abstract implementation of update strategy"""
|
45
|
-
|
46
|
-
|
47
|
-
@dataclass(frozen=True)
|
48
|
-
class DriftBaseOutput(Output):
|
49
|
-
"""
|
50
|
-
Base output class for Drift Detector classes
|
51
|
-
"""
|
52
|
-
|
53
|
-
drifted: bool
|
54
|
-
threshold: float
|
55
|
-
p_val: float
|
56
|
-
distance: float
|
57
|
-
|
58
|
-
|
59
|
-
@dataclass(frozen=True)
|
60
|
-
class DriftOutput(DriftBaseOutput):
|
61
|
-
"""
|
62
|
-
Output class for :class:`.DriftCVM`, :class:`.DriftKS`, and :class:`.DriftUncertainty` drift detectors.
|
63
|
-
|
64
|
-
Attributes
|
65
|
-
----------
|
66
|
-
drifted : bool
|
67
|
-
:term:`Drift` prediction for the images
|
68
|
-
threshold : float
|
69
|
-
Threshold after multivariate correction if needed
|
70
|
-
p_val : float
|
71
|
-
Instance-level p-value
|
72
|
-
distance : float
|
73
|
-
Instance-level distance
|
74
|
-
feature_drift : NDArray
|
75
|
-
Feature-level array of images detected to have drifted
|
76
|
-
feature_threshold : float
|
77
|
-
Feature-level threshold to determine drift
|
78
|
-
p_vals : NDArray
|
79
|
-
Feature-level p-values
|
80
|
-
distances : NDArray
|
81
|
-
Feature-level distances
|
32
|
+
Protocol for reference dataset update strategy for drift detectors
|
82
33
|
"""
|
83
34
|
|
84
|
-
|
85
|
-
# threshold: float
|
86
|
-
# p_val: float
|
87
|
-
# distance: float
|
88
|
-
feature_drift: NDArray[np.bool_]
|
89
|
-
feature_threshold: float
|
90
|
-
p_vals: NDArray[np.float32]
|
91
|
-
distances: NDArray[np.float32]
|
35
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
|
92
36
|
|
93
37
|
|
94
38
|
def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
|
dataeval/detectors/drift/_mmd.py
CHANGED
@@ -10,44 +10,18 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
from dataclasses import dataclass
|
14
13
|
from typing import Callable
|
15
14
|
|
16
15
|
import torch
|
17
16
|
|
18
|
-
from dataeval.
|
19
|
-
from dataeval.
|
20
|
-
from dataeval.detectors.drift._base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
|
17
|
+
from dataeval.config import DeviceLike, get_device
|
18
|
+
from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
|
21
19
|
from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
|
20
|
+
from dataeval.outputs import DriftMMDOutput
|
21
|
+
from dataeval.outputs._base import set_metadata
|
22
22
|
from dataeval.typing import ArrayLike
|
23
23
|
|
24
24
|
|
25
|
-
@dataclass(frozen=True)
|
26
|
-
class DriftMMDOutput(DriftBaseOutput):
|
27
|
-
"""
|
28
|
-
Output class for :class:`.DriftMMD` :term:`drift<Drift>` detector.
|
29
|
-
|
30
|
-
Attributes
|
31
|
-
----------
|
32
|
-
drifted : bool
|
33
|
-
Drift prediction for the images
|
34
|
-
threshold : float
|
35
|
-
:term:`P-Value` used for significance of the permutation test
|
36
|
-
p_val : float
|
37
|
-
P-value obtained from the permutation test
|
38
|
-
distance : float
|
39
|
-
MMD^2 between the reference and test set
|
40
|
-
distance_threshold : float
|
41
|
-
MMD^2 threshold above which drift is flagged
|
42
|
-
"""
|
43
|
-
|
44
|
-
# drifted: bool
|
45
|
-
# threshold: float
|
46
|
-
# p_val: float
|
47
|
-
# distance: float
|
48
|
-
distance_threshold: float
|
49
|
-
|
50
|
-
|
51
25
|
class DriftMMD(BaseDrift):
|
52
26
|
"""
|
53
27
|
:term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm \
|
@@ -57,7 +31,7 @@ class DriftMMD(BaseDrift):
|
|
57
31
|
----------
|
58
32
|
x_ref : ArrayLike
|
59
33
|
Data used as reference distribution.
|
60
|
-
p_val : float
|
34
|
+
p_val : float or None, default 0.05
|
61
35
|
:term:`P-value` used for significance of the statistical test for each feature.
|
62
36
|
If the FDR correction method is used, this corresponds to the acceptable
|
63
37
|
q-value.
|
@@ -65,14 +39,14 @@ class DriftMMD(BaseDrift):
|
|
65
39
|
Whether the given reference data ``x_ref`` has been preprocessed yet.
|
66
40
|
If ``True``, only the test data ``x`` will be preprocessed at prediction time.
|
67
41
|
If ``False``, the reference data will also be preprocessed.
|
68
|
-
update_x_ref : UpdateStrategy
|
42
|
+
update_x_ref : UpdateStrategy or None, default None
|
69
43
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
70
44
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
71
45
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
72
|
-
preprocess_fn : Callable
|
46
|
+
preprocess_fn : Callable or None, default None
|
73
47
|
Function to preprocess the data before computing the data drift metrics.
|
74
48
|
Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
|
75
|
-
sigma : ArrayLike
|
49
|
+
sigma : ArrayLike or None, default None
|
76
50
|
Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
|
77
51
|
bandwidth values as an array. The kernel evaluation is then averaged over
|
78
52
|
those bandwidths.
|
@@ -80,9 +54,9 @@ class DriftMMD(BaseDrift):
|
|
80
54
|
Whether to already configure the kernel bandwidth from the reference data.
|
81
55
|
n_permutations : int, default 100
|
82
56
|
Number of permutations used in the permutation test.
|
83
|
-
device :
|
84
|
-
|
85
|
-
|
57
|
+
device : DeviceLike or None, default None
|
58
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
59
|
+
default or torch default.
|
86
60
|
|
87
61
|
Example
|
88
62
|
-------
|
@@ -110,7 +84,7 @@ class DriftMMD(BaseDrift):
|
|
110
84
|
sigma: ArrayLike | None = None,
|
111
85
|
configure_kernel_from_x_ref: bool = True,
|
112
86
|
n_permutations: int = 100,
|
113
|
-
device:
|
87
|
+
device: DeviceLike | None = None,
|
114
88
|
) -> None:
|
115
89
|
super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
|
116
90
|
|
@@ -17,7 +17,7 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
from numpy.typing import NDArray
|
19
19
|
|
20
|
-
from dataeval.config import get_device
|
20
|
+
from dataeval.config import DeviceLike, get_device
|
21
21
|
from dataeval.utils.torch._internal import predict_batch
|
22
22
|
|
23
23
|
|
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
|
|
59
59
|
def preprocess_drift(
|
60
60
|
x: NDArray[Any],
|
61
61
|
model: nn.Module,
|
62
|
-
device:
|
62
|
+
device: DeviceLike | None = None,
|
63
63
|
preprocess_batch_fn: Callable | None = None,
|
64
64
|
batch_size: int = int(1e10),
|
65
65
|
dtype: type[np.generic] | torch.dtype = np.float32,
|
@@ -73,15 +73,15 @@ def preprocess_drift(
|
|
73
73
|
Batch of instances.
|
74
74
|
model : nn.Module
|
75
75
|
Model used for preprocessing.
|
76
|
-
device :
|
77
|
-
|
78
|
-
|
79
|
-
preprocess_batch_fn : Callable
|
76
|
+
device : DeviceLike or None, default None
|
77
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
78
|
+
default or torch default.
|
79
|
+
preprocess_batch_fn : Callable or None, default None
|
80
80
|
Optional batch preprocessing function. For example to convert a list of objects
|
81
81
|
to a batch which can be processed by the PyTorch model.
|
82
82
|
batch_size : int, default 1e10
|
83
83
|
Batch size used during prediction.
|
84
|
-
dtype : np.dtype
|
84
|
+
dtype : np.dtype or torch.dtype, default np.float32
|
85
85
|
Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
|
86
86
|
|
87
87
|
Returns
|
@@ -19,9 +19,10 @@ from scipy.special import softmax
|
|
19
19
|
from scipy.stats import entropy
|
20
20
|
|
21
21
|
from dataeval.config import get_device
|
22
|
-
from dataeval.detectors.drift._base import
|
22
|
+
from dataeval.detectors.drift._base import UpdateStrategy
|
23
23
|
from dataeval.detectors.drift._ks import DriftKS
|
24
24
|
from dataeval.detectors.drift._torch import preprocess_drift
|
25
|
+
from dataeval.outputs import DriftOutput
|
25
26
|
from dataeval.typing import ArrayLike
|
26
27
|
|
27
28
|
|
@@ -84,20 +85,20 @@ class DriftUncertainty:
|
|
84
85
|
Whether the given reference data ``x_ref`` has been preprocessed yet.
|
85
86
|
If ``True``, only the test data ``x`` will be preprocessed at prediction time.
|
86
87
|
If ``False``, the reference data will also be preprocessed.
|
87
|
-
update_x_ref : UpdateStrategy
|
88
|
+
update_x_ref : UpdateStrategy or None, default None
|
88
89
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
89
90
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
90
91
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
91
|
-
preds_type : "probs"
|
92
|
+
preds_type : "probs" or "logits", default "probs"
|
92
93
|
Type of prediction output by the model. Options are 'probs' (in [0,1]) or
|
93
94
|
'logits' (in [-inf,inf]).
|
94
95
|
batch_size : int, default 32
|
95
96
|
Batch size used to evaluate model. Only relevant when backend has been
|
96
97
|
specified for batch prediction.
|
97
|
-
preprocess_batch_fn : Callable
|
98
|
+
preprocess_batch_fn : Callable or None, default None
|
98
99
|
Optional batch preprocessing function. For example to convert a list of
|
99
100
|
objects to a batch which can be processed by the model.
|
100
|
-
device :
|
101
|
+
device : DeviceLike or None, default None
|
101
102
|
Device type used. The default None tries to use the GPU and falls back on
|
102
103
|
CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
|
103
104
|
|
@@ -7,15 +7,32 @@ from __future__ import annotations
|
|
7
7
|
|
8
8
|
__all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
|
9
9
|
|
10
|
+
from abc import ABC, abstractmethod
|
10
11
|
from typing import Any
|
11
12
|
|
12
13
|
import numpy as np
|
13
14
|
from numpy.typing import NDArray
|
14
15
|
|
15
|
-
|
16
|
+
|
17
|
+
class BaseUpdateStrategy(ABC):
|
18
|
+
"""
|
19
|
+
Updates reference dataset for drift detector
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
n : int
|
24
|
+
Update with last n instances seen by the detector.
|
25
|
+
"""
|
26
|
+
|
27
|
+
def __init__(self, n: int) -> None:
|
28
|
+
self.n = n
|
29
|
+
|
30
|
+
@abstractmethod
|
31
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
32
|
+
"""Abstract implementation of update strategy"""
|
16
33
|
|
17
34
|
|
18
|
-
class LastSeenUpdate(
|
35
|
+
class LastSeenUpdate(BaseUpdateStrategy):
|
19
36
|
"""
|
20
37
|
Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
|
21
38
|
|
@@ -30,7 +47,7 @@ class LastSeenUpdate(UpdateStrategy):
|
|
30
47
|
return x_updated[-self.n :]
|
31
48
|
|
32
49
|
|
33
|
-
class ReservoirSamplingUpdate(
|
50
|
+
class ReservoirSamplingUpdate(BaseUpdateStrategy):
|
34
51
|
"""
|
35
52
|
Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
|
36
53
|
|
@@ -9,5 +9,6 @@ __all__ = [
|
|
9
9
|
"OutliersOutput",
|
10
10
|
]
|
11
11
|
|
12
|
-
from dataeval.detectors.linters.duplicates import Duplicates
|
13
|
-
from dataeval.detectors.linters.outliers import Outliers
|
12
|
+
from dataeval.detectors.linters.duplicates import Duplicates
|
13
|
+
from dataeval.detectors.linters.outliers import Outliers
|
14
|
+
from dataeval.outputs._linters import DuplicatesOutput, OutliersOutput
|
@@ -2,40 +2,15 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from
|
6
|
-
from typing import Any, Generic, Iterable, Sequence, TypeVar, overload
|
5
|
+
from typing import Any, Sequence, overload
|
7
6
|
|
8
|
-
from
|
9
|
-
|
10
|
-
from dataeval._output import Output, set_metadata
|
7
|
+
from dataeval.metrics.stats import hashstats
|
11
8
|
from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
|
18
|
-
|
19
|
-
|
20
|
-
@dataclass(frozen=True)
|
21
|
-
class DuplicatesOutput(Generic[TIndexCollection], Output):
|
22
|
-
"""
|
23
|
-
Output class for :class:`.Duplicates` lint detector.
|
24
|
-
|
25
|
-
Attributes
|
26
|
-
----------
|
27
|
-
exact : list[list[int] | dict[int, list[int]]]
|
28
|
-
Indices of images that are exact matches
|
29
|
-
near: list[list[int] | dict[int, list[int]]]
|
30
|
-
Indices of images that are near matches
|
31
|
-
|
32
|
-
- For a single dataset, indices are returned as a list of index groups.
|
33
|
-
- For multiple datasets, indices are returned as dictionaries where the key is the
|
34
|
-
index of the dataset, and the value is the list index groups from that dataset.
|
35
|
-
"""
|
36
|
-
|
37
|
-
exact: list[TIndexCollection]
|
38
|
-
near: list[TIndexCollection]
|
9
|
+
from dataeval.outputs import DuplicatesOutput, HashStatsOutput
|
10
|
+
from dataeval.outputs._base import set_metadata
|
11
|
+
from dataeval.outputs._linters import DatasetDuplicateGroupMap, DuplicateGroup
|
12
|
+
from dataeval.typing import Array, Dataset
|
13
|
+
from dataeval.utils.data._images import Images
|
39
14
|
|
40
15
|
|
41
16
|
class Duplicates:
|
@@ -113,13 +88,13 @@ class Duplicates:
|
|
113
88
|
"""
|
114
89
|
|
115
90
|
if isinstance(hashes, HashStatsOutput):
|
116
|
-
return DuplicatesOutput(**self._get_duplicates(hashes.
|
91
|
+
return DuplicatesOutput(**self._get_duplicates(hashes.data()))
|
117
92
|
|
118
93
|
if not isinstance(hashes, Sequence):
|
119
94
|
raise TypeError("Invalid stats output type; only use output from hashstats.")
|
120
95
|
|
121
96
|
combined, dataset_steps = combine_stats(hashes)
|
122
|
-
duplicates = self._get_duplicates(combined.
|
97
|
+
duplicates = self._get_duplicates(combined.data())
|
123
98
|
|
124
99
|
# split up results from combined dataset into individual dataset buckets
|
125
100
|
for dup_type, dup_list in duplicates.items():
|
@@ -134,22 +109,15 @@ class Duplicates:
|
|
134
109
|
|
135
110
|
return DuplicatesOutput(**duplicates)
|
136
111
|
|
137
|
-
@overload
|
138
|
-
def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]: ...
|
139
|
-
@overload
|
140
|
-
def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> DuplicatesOutput[DuplicateGroup]: ...
|
141
|
-
|
142
112
|
@set_metadata(state=["only_exact"])
|
143
|
-
def evaluate(
|
144
|
-
self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
|
145
|
-
) -> DuplicatesOutput[DuplicateGroup]:
|
113
|
+
def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> DuplicatesOutput[DuplicateGroup]:
|
146
114
|
"""
|
147
115
|
Returns duplicate image indices for both exact matches and near matches
|
148
116
|
|
149
117
|
Parameters
|
150
118
|
----------
|
151
|
-
data : Iterable[
|
152
|
-
A dataset of images in an
|
119
|
+
data : Iterable[Array], shape - (N, C, H, W) | Dataset[tuple[Array, Any, Any]]
|
120
|
+
A dataset of images in an Array format or the output(s) from a hashstats analysis
|
153
121
|
|
154
122
|
Returns
|
155
123
|
-------
|
@@ -166,7 +134,7 @@ class Duplicates:
|
|
166
134
|
>>> all_dupes.evaluate(duplicate_images)
|
167
135
|
DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
|
168
136
|
""" # noqa: E501
|
169
|
-
images = (
|
137
|
+
images = Images(data) if isinstance(data, Dataset) else data
|
170
138
|
self.stats = hashstats(images)
|
171
|
-
duplicates = self._get_duplicates(self.stats.
|
139
|
+
duplicates = self._get_duplicates(self.stats.data())
|
172
140
|
return DuplicatesOutput(**duplicates)
|