dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +9 -9
- dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
- dataeval/detectors/ood/base.py +63 -113
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/metadata_ks_compare.py +52 -14
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +73 -70
- dataeval/metrics/bias/coverage.py +4 -4
- dataeval/metrics/bias/diversity.py +67 -136
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +51 -161
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/aegmm.py +0 -66
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -97
- dataeval/detectors/ood/vaegmm.py +0 -75
- dataeval/metrics/bias/metadata.py +0 -440
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -123
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.73.1.dist-info/RECORD +0 -73
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.74.1"
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
5
5
|
_IS_TORCH_AVAILABLE = find_spec("torch") is not None
|
6
6
|
_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
|
7
|
-
_IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
|
8
7
|
|
9
8
|
del find_spec
|
10
9
|
|
@@ -13,11 +12,6 @@ from dataeval import detectors, metrics # noqa: E402
|
|
13
12
|
__all__ = ["detectors", "metrics"]
|
14
13
|
|
15
14
|
if _IS_TORCH_AVAILABLE:
|
16
|
-
from dataeval import workflows
|
15
|
+
from dataeval import utils, workflows
|
17
16
|
|
18
|
-
__all__ += ["workflows"]
|
19
|
-
|
20
|
-
if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
|
21
|
-
from dataeval import utils
|
22
|
-
|
23
|
-
__all__ += ["utils"]
|
17
|
+
__all__ += ["utils", "workflows"]
|
dataeval/detectors/__init__.py
CHANGED
@@ -2,14 +2,6 @@
|
|
2
2
|
Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from dataeval import
|
6
|
-
from dataeval.detectors import drift, linters
|
5
|
+
from dataeval.detectors import drift, linters, ood
|
7
6
|
|
8
|
-
__all__ = ["drift", "linters"]
|
9
|
-
|
10
|
-
if _IS_TENSORFLOW_AVAILABLE:
|
11
|
-
from dataeval.detectors import ood
|
12
|
-
|
13
|
-
__all__ += ["ood"]
|
14
|
-
|
15
|
-
del _IS_TENSORFLOW_AVAILABLE
|
7
|
+
__all__ = ["drift", "linters", "ood"]
|
dataeval/detectors/drift/base.py
CHANGED
@@ -19,7 +19,7 @@ import numpy as np
|
|
19
19
|
from numpy.typing import ArrayLike, NDArray
|
20
20
|
|
21
21
|
from dataeval.interop import as_numpy
|
22
|
-
from dataeval.output import
|
22
|
+
from dataeval.output import Output, set_metadata
|
23
23
|
|
24
24
|
R = TypeVar("R")
|
25
25
|
|
@@ -43,7 +43,7 @@ class UpdateStrategy(ABC):
|
|
43
43
|
|
44
44
|
|
45
45
|
@dataclass(frozen=True)
|
46
|
-
class DriftBaseOutput(
|
46
|
+
class DriftBaseOutput(Output):
|
47
47
|
"""
|
48
48
|
Base output class for Drift detector classes
|
49
49
|
|
@@ -387,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
387
387
|
else:
|
388
388
|
raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
|
389
389
|
|
390
|
-
@set_metadata
|
390
|
+
@set_metadata
|
391
391
|
@preprocess_x
|
392
392
|
@update_x_ref
|
393
393
|
def predict(
|
dataeval/detectors/drift/mmd.py
CHANGED
@@ -161,7 +161,7 @@ class DriftMMD(BaseDrift):
|
|
161
161
|
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
|
162
162
|
return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
|
163
163
|
|
164
|
-
@set_metadata
|
164
|
+
@set_metadata
|
165
165
|
@preprocess_x
|
166
166
|
@update_x_ref
|
167
167
|
def predict(self, x: ArrayLike) -> DriftMMDOutput:
|
@@ -10,7 +10,6 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
from functools import partial
|
14
13
|
from typing import Any, Callable
|
15
14
|
|
16
15
|
import numpy as np
|
@@ -18,30 +17,7 @@ import torch
|
|
18
17
|
import torch.nn as nn
|
19
18
|
from numpy.typing import NDArray
|
20
19
|
|
21
|
-
|
22
|
-
def get_device(device: str | torch.device | None = None) -> torch.device:
|
23
|
-
"""
|
24
|
-
Instantiates a PyTorch device object.
|
25
|
-
|
26
|
-
Parameters
|
27
|
-
----------
|
28
|
-
device : str | torch.device | None, default None
|
29
|
-
Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
|
30
|
-
already instantiated device object. If ``None``, the GPU is selected if it is
|
31
|
-
detected, otherwise the CPU is used as a fallback.
|
32
|
-
|
33
|
-
Returns
|
34
|
-
-------
|
35
|
-
The instantiated device object.
|
36
|
-
"""
|
37
|
-
if isinstance(device, torch.device): # Already a torch device
|
38
|
-
return device
|
39
|
-
else: # Instantiate device
|
40
|
-
if device is None or device.lower() in ["gpu", "cuda"]:
|
41
|
-
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42
|
-
else:
|
43
|
-
torch_device = torch.device("cpu")
|
44
|
-
return torch_device
|
20
|
+
from dataeval.utils.torch.utils import get_device, predict_batch
|
45
21
|
|
46
22
|
|
47
23
|
def _mmd2_from_kernel_matrix(
|
@@ -79,82 +55,6 @@ def _mmd2_from_kernel_matrix(
|
|
79
55
|
return mmd2
|
80
56
|
|
81
57
|
|
82
|
-
def predict_batch(
|
83
|
-
x: NDArray[Any] | torch.Tensor,
|
84
|
-
model: Callable | nn.Module | nn.Sequential,
|
85
|
-
device: torch.device | None = None,
|
86
|
-
batch_size: int = int(1e10),
|
87
|
-
preprocess_fn: Callable | None = None,
|
88
|
-
dtype: type[np.generic] | torch.dtype = np.float32,
|
89
|
-
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
|
90
|
-
"""
|
91
|
-
Make batch predictions on a model.
|
92
|
-
|
93
|
-
Parameters
|
94
|
-
----------
|
95
|
-
x : np.ndarray | torch.Tensor
|
96
|
-
Batch of instances.
|
97
|
-
model : Callable | nn.Module | nn.Sequential
|
98
|
-
PyTorch model.
|
99
|
-
device : torch.device | None, default None
|
100
|
-
Device type used. The default None tries to use the GPU and falls back on CPU.
|
101
|
-
Can be specified by passing either torch.device('cuda') or torch.device('cpu').
|
102
|
-
batch_size : int, default 1e10
|
103
|
-
Batch size used during prediction.
|
104
|
-
preprocess_fn : Callable | None, default None
|
105
|
-
Optional preprocessing function for each batch.
|
106
|
-
dtype : np.dtype | torch.dtype, default np.float32
|
107
|
-
Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
|
108
|
-
|
109
|
-
Returns
|
110
|
-
-------
|
111
|
-
NDArray | torch.Tensor | tuple
|
112
|
-
Numpy array, torch tensor or tuples of those with model outputs.
|
113
|
-
"""
|
114
|
-
device = get_device(device)
|
115
|
-
if isinstance(x, np.ndarray):
|
116
|
-
x = torch.from_numpy(x)
|
117
|
-
n = len(x)
|
118
|
-
n_minibatch = int(np.ceil(n / batch_size))
|
119
|
-
return_np = not isinstance(dtype, torch.dtype)
|
120
|
-
preds = []
|
121
|
-
with torch.no_grad():
|
122
|
-
for i in range(n_minibatch):
|
123
|
-
istart, istop = i * batch_size, min((i + 1) * batch_size, n)
|
124
|
-
x_batch = x[istart:istop]
|
125
|
-
if isinstance(preprocess_fn, Callable):
|
126
|
-
x_batch = preprocess_fn(x_batch)
|
127
|
-
preds_tmp = model(x_batch.to(device))
|
128
|
-
if isinstance(preds_tmp, (list, tuple)):
|
129
|
-
if len(preds) == 0: # init tuple with lists to store predictions
|
130
|
-
preds = tuple([] for _ in range(len(preds_tmp)))
|
131
|
-
for j, p in enumerate(preds_tmp):
|
132
|
-
if isinstance(p, torch.Tensor):
|
133
|
-
p = p.cpu()
|
134
|
-
preds[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
|
135
|
-
elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
|
136
|
-
if isinstance(preds_tmp, torch.Tensor):
|
137
|
-
preds_tmp = preds_tmp.cpu()
|
138
|
-
if isinstance(preds, tuple):
|
139
|
-
preds = list(preds)
|
140
|
-
preds.append(
|
141
|
-
preds_tmp
|
142
|
-
if not return_np or isinstance(preds_tmp, np.ndarray) # type: ignore
|
143
|
-
else preds_tmp.numpy()
|
144
|
-
)
|
145
|
-
else:
|
146
|
-
raise TypeError(
|
147
|
-
f"Model output type {type(preds_tmp)} not supported. The model \
|
148
|
-
output type needs to be one of list, tuple, NDArray or \
|
149
|
-
torch.Tensor."
|
150
|
-
)
|
151
|
-
concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
|
152
|
-
out: tuple | np.ndarray | torch.Tensor = (
|
153
|
-
tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
|
154
|
-
)
|
155
|
-
return out
|
156
|
-
|
157
|
-
|
158
58
|
def preprocess_drift(
|
159
59
|
x: NDArray[Any],
|
160
60
|
model: nn.Module,
|
@@ -11,12 +11,12 @@ from scipy.cluster.hierarchy import linkage
|
|
11
11
|
from scipy.spatial.distance import pdist, squareform
|
12
12
|
|
13
13
|
from dataeval.interop import to_numpy
|
14
|
-
from dataeval.output import
|
14
|
+
from dataeval.output import Output, set_metadata
|
15
15
|
from dataeval.utils.shared import flatten
|
16
16
|
|
17
17
|
|
18
18
|
@dataclass(frozen=True)
|
19
|
-
class ClustererOutput(
|
19
|
+
class ClustererOutput(Output):
|
20
20
|
"""
|
21
21
|
Output class for :class:`Clusterer` lint detector
|
22
22
|
|
@@ -495,7 +495,7 @@ class Clusterer:
|
|
495
495
|
return exact_dupes, near_dupes
|
496
496
|
|
497
497
|
# TODO: Move data input to evaluate from class
|
498
|
-
@set_metadata(["data"])
|
498
|
+
@set_metadata(state=["data"])
|
499
499
|
def evaluate(self) -> ClustererOutput:
|
500
500
|
"""Finds and flags indices of the data for Outliers and :term:`duplicates<Duplicates>`
|
501
501
|
|
@@ -9,7 +9,7 @@ from numpy.typing import ArrayLike
|
|
9
9
|
|
10
10
|
from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
|
11
11
|
from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
|
12
|
-
from dataeval.output import
|
12
|
+
from dataeval.output import Output, set_metadata
|
13
13
|
|
14
14
|
DuplicateGroup = list[int]
|
15
15
|
DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
|
@@ -17,7 +17,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
|
|
17
17
|
|
18
18
|
|
19
19
|
@dataclass(frozen=True)
|
20
|
-
class DuplicatesOutput(Generic[TIndexCollection],
|
20
|
+
class DuplicatesOutput(Generic[TIndexCollection], Output):
|
21
21
|
"""
|
22
22
|
Output class for :class:`Duplicates` lint detector
|
23
23
|
|
@@ -89,7 +89,7 @@ class Duplicates:
|
|
89
89
|
@overload
|
90
90
|
def from_stats(self, hashes: Sequence[HashStatsOutput]) -> DuplicatesOutput[DatasetDuplicateGroupMap]: ...
|
91
91
|
|
92
|
-
@set_metadata(["only_exact"])
|
92
|
+
@set_metadata(state=["only_exact"])
|
93
93
|
def from_stats(
|
94
94
|
self, hashes: HashStatsOutput | Sequence[HashStatsOutput]
|
95
95
|
) -> DuplicatesOutput[DuplicateGroup] | DuplicatesOutput[DatasetDuplicateGroupMap]:
|
@@ -138,7 +138,7 @@ class Duplicates:
|
|
138
138
|
|
139
139
|
return DuplicatesOutput(**duplicates)
|
140
140
|
|
141
|
-
@set_metadata(["only_exact"])
|
141
|
+
@set_metadata(state=["only_exact"])
|
142
142
|
def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
|
143
143
|
"""
|
144
144
|
Returns duplicate image indices for both exact matches and near matches
|
@@ -14,7 +14,7 @@ from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
|
|
14
14
|
from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
|
15
15
|
from dataeval.metrics.stats.pixelstats import PixelStatsOutput
|
16
16
|
from dataeval.metrics.stats.visualstats import VisualStatsOutput
|
17
|
-
from dataeval.output import
|
17
|
+
from dataeval.output import Output, set_metadata
|
18
18
|
|
19
19
|
IndexIssueMap = dict[int, dict[str, float]]
|
20
20
|
OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
|
@@ -22,7 +22,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
|
|
22
22
|
|
23
23
|
|
24
24
|
@dataclass(frozen=True)
|
25
|
-
class OutliersOutput(Generic[TIndexIssueMap],
|
25
|
+
class OutliersOutput(Generic[TIndexIssueMap], Output):
|
26
26
|
"""
|
27
27
|
Output class for :class:`Outliers` lint detector
|
28
28
|
|
@@ -159,7 +159,7 @@ class Outliers:
|
|
159
159
|
@overload
|
160
160
|
def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
|
161
161
|
|
162
|
-
@set_metadata(["outlier_method", "outlier_threshold"])
|
162
|
+
@set_metadata(state=["outlier_method", "outlier_threshold"])
|
163
163
|
def from_stats(
|
164
164
|
self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
|
165
165
|
) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
|
@@ -228,7 +228,7 @@ class Outliers:
|
|
228
228
|
|
229
229
|
return OutliersOutput(output_list)
|
230
230
|
|
231
|
-
@set_metadata(["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
|
231
|
+
@set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
|
232
232
|
def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
|
233
233
|
"""
|
234
234
|
Returns indices of Outliers with the issues identified for each
|
@@ -2,14 +2,14 @@
|
|
2
2
|
Out-of-distribution (OOD)` detectors identify data that is different from the data used to train a particular model.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from dataeval import
|
5
|
+
from dataeval import _IS_TORCH_AVAILABLE
|
6
|
+
from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
|
6
7
|
|
7
|
-
|
8
|
-
from dataeval.detectors.ood.ae import OOD_AE
|
9
|
-
from dataeval.detectors.ood.aegmm import OOD_AEGMM
|
10
|
-
from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
|
11
|
-
from dataeval.detectors.ood.llr import OOD_LLR
|
12
|
-
from dataeval.detectors.ood.vae import OOD_VAE
|
13
|
-
from dataeval.detectors.ood.vaegmm import OOD_VAEGMM
|
8
|
+
__all__ = ["OODOutput", "OODScoreOutput"]
|
14
9
|
|
15
|
-
|
10
|
+
if _IS_TORCH_AVAILABLE:
|
11
|
+
from dataeval.detectors.ood.ae_torch import OOD_AE
|
12
|
+
|
13
|
+
__all__ += ["OOD_AE"]
|
14
|
+
|
15
|
+
del _IS_TORCH_AVAILABLE
|
@@ -1,4 +1,6 @@
|
|
1
1
|
"""
|
2
|
+
Adapted for Pytorch from
|
3
|
+
|
2
4
|
Source code derived from Alibi-Detect 0.11.4
|
3
5
|
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
|
4
6
|
|
@@ -8,55 +10,48 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
10
|
|
9
11
|
from __future__ import annotations
|
10
12
|
|
11
|
-
|
12
|
-
|
13
|
-
from typing import TYPE_CHECKING, Callable
|
13
|
+
from typing import Callable
|
14
14
|
|
15
15
|
import numpy as np
|
16
|
+
import torch
|
16
17
|
from numpy.typing import ArrayLike
|
17
18
|
|
18
|
-
from dataeval.detectors.ood.base import
|
19
|
+
from dataeval.detectors.ood.base import OODScoreOutput
|
20
|
+
from dataeval.detectors.ood.base_torch import OODBase
|
19
21
|
from dataeval.interop import as_numpy
|
20
|
-
from dataeval.utils.
|
21
|
-
from dataeval.utils.tensorflow._internal.utils import predict_batch
|
22
|
-
|
23
|
-
if TYPE_CHECKING:
|
24
|
-
import tensorflow as tf
|
25
|
-
import tf_keras as keras
|
26
|
-
|
27
|
-
import dataeval.utils.tensorflow._internal.models as tf_models
|
28
|
-
else:
|
29
|
-
tf = lazyload("tensorflow")
|
30
|
-
keras = lazyload("tf_keras")
|
31
|
-
tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
|
22
|
+
from dataeval.utils.torch.utils import predict_batch
|
32
23
|
|
33
24
|
|
34
25
|
class OOD_AE(OODBase):
|
35
26
|
"""
|
36
|
-
Autoencoder
|
27
|
+
Autoencoder based out-of-distribution detector.
|
37
28
|
|
38
29
|
Parameters
|
39
30
|
----------
|
40
|
-
model :
|
41
|
-
|
31
|
+
model : AriaAutoencoder
|
32
|
+
An Autoencoder model.
|
42
33
|
"""
|
43
34
|
|
44
|
-
def __init__(self, model:
|
45
|
-
super().__init__(model)
|
35
|
+
def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
|
36
|
+
super().__init__(model, device)
|
46
37
|
|
47
38
|
def fit(
|
48
39
|
self,
|
49
40
|
x_ref: ArrayLike,
|
50
|
-
threshold_perc: float
|
51
|
-
loss_fn: Callable[...,
|
52
|
-
optimizer:
|
41
|
+
threshold_perc: float,
|
42
|
+
loss_fn: Callable[..., torch.nn.Module] | None = None,
|
43
|
+
optimizer: torch.optim.Optimizer | None = None,
|
53
44
|
epochs: int = 20,
|
54
45
|
batch_size: int = 64,
|
55
|
-
verbose: bool =
|
46
|
+
verbose: bool = False,
|
56
47
|
) -> None:
|
57
48
|
if loss_fn is None:
|
58
|
-
loss_fn =
|
59
|
-
|
49
|
+
loss_fn = torch.nn.MSELoss()
|
50
|
+
|
51
|
+
if optimizer is None:
|
52
|
+
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
53
|
+
|
54
|
+
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
60
55
|
|
61
56
|
def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
62
57
|
self._validate(X := as_numpy(X))
|
dataeval/detectors/ood/base.py
CHANGED
@@ -12,27 +12,18 @@ __all__ = ["OODOutput", "OODScoreOutput"]
|
|
12
12
|
|
13
13
|
from abc import ABC, abstractmethod
|
14
14
|
from dataclasses import dataclass
|
15
|
-
from typing import
|
15
|
+
from typing import Callable, Generic, Literal, TypeVar
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
from numpy.typing import ArrayLike, NDArray
|
19
19
|
|
20
20
|
from dataeval.interop import to_numpy
|
21
|
-
from dataeval.output import
|
22
|
-
from dataeval.utils.
|
23
|
-
from dataeval.utils.tensorflow._internal.gmm import GaussianMixtureModelParams, gmm_params
|
24
|
-
from dataeval.utils.tensorflow._internal.trainer import trainer
|
25
|
-
|
26
|
-
if TYPE_CHECKING:
|
27
|
-
import tensorflow as tf
|
28
|
-
import tf_keras as keras
|
29
|
-
else:
|
30
|
-
tf = lazyload("tensorflow")
|
31
|
-
keras = lazyload("tf_keras")
|
21
|
+
from dataeval.output import Output, set_metadata
|
22
|
+
from dataeval.utils.gmm import GaussianMixtureModelParams
|
32
23
|
|
33
24
|
|
34
25
|
@dataclass(frozen=True)
|
35
|
-
class OODOutput(
|
26
|
+
class OODOutput(Output):
|
36
27
|
"""
|
37
28
|
Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
|
38
29
|
:class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
@@ -53,7 +44,7 @@ class OODOutput(OutputMetadata):
|
|
53
44
|
|
54
45
|
|
55
46
|
@dataclass(frozen=True)
|
56
|
-
class OODScoreOutput(
|
47
|
+
class OODScoreOutput(Output):
|
57
48
|
"""
|
58
49
|
Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
|
59
50
|
:class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
@@ -85,16 +76,62 @@ class OODScoreOutput(OutputMetadata):
|
|
85
76
|
return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
|
86
77
|
|
87
78
|
|
88
|
-
|
89
|
-
|
90
|
-
|
79
|
+
TGMMData = TypeVar("TGMMData")
|
80
|
+
|
81
|
+
|
82
|
+
class OODGMMMixin(Generic[TGMMData]):
|
83
|
+
_gmm_params: GaussianMixtureModelParams[TGMMData]
|
84
|
+
|
91
85
|
|
92
|
-
|
93
|
-
|
94
|
-
|
86
|
+
TModel = TypeVar("TModel", bound=Callable)
|
87
|
+
TLossFn = TypeVar("TLossFn", bound=Callable)
|
88
|
+
TOptimizer = TypeVar("TOptimizer")
|
89
|
+
|
90
|
+
|
91
|
+
class OODFitMixin(Generic[TLossFn, TOptimizer], ABC):
|
92
|
+
@abstractmethod
|
93
|
+
def fit(
|
94
|
+
self,
|
95
|
+
x_ref: ArrayLike,
|
96
|
+
threshold_perc: float,
|
97
|
+
loss_fn: TLossFn | None,
|
98
|
+
optimizer: TOptimizer | None,
|
99
|
+
epochs: int,
|
100
|
+
batch_size: int,
|
101
|
+
verbose: bool,
|
102
|
+
) -> None:
|
103
|
+
"""
|
104
|
+
Train the model and infer the threshold value.
|
95
105
|
|
96
|
-
|
97
|
-
|
106
|
+
Parameters
|
107
|
+
----------
|
108
|
+
x_ref : ArrayLike
|
109
|
+
Training data.
|
110
|
+
threshold_perc : float, default 100.0
|
111
|
+
Percentage of reference data that is normal.
|
112
|
+
loss_fn : TLossFn
|
113
|
+
Loss function used for training.
|
114
|
+
optimizer : TOptimizer
|
115
|
+
Optimizer used for training.
|
116
|
+
epochs : int, default 20
|
117
|
+
Number of training epochs.
|
118
|
+
batch_size : int, default 64
|
119
|
+
Batch size used for training.
|
120
|
+
verbose : bool, default True
|
121
|
+
Whether to print training progress.
|
122
|
+
"""
|
123
|
+
|
124
|
+
|
125
|
+
class OODBaseMixin(Generic[TModel], ABC):
|
126
|
+
_ref_score: OODScoreOutput
|
127
|
+
_threshold_perc: float
|
128
|
+
_data_info: tuple[tuple, type] | None = None
|
129
|
+
|
130
|
+
def __init__(
|
131
|
+
self,
|
132
|
+
model: TModel,
|
133
|
+
) -> None:
|
134
|
+
self.model = model
|
98
135
|
|
99
136
|
def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
|
100
137
|
if not isinstance(X, np.ndarray):
|
@@ -107,9 +144,8 @@ class OODBase(ABC):
|
|
107
144
|
raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
|
108
145
|
Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
|
109
146
|
|
110
|
-
def _validate_state(self, X: NDArray
|
111
|
-
attrs = [
|
112
|
-
attrs = attrs if additional_attrs is None else attrs + additional_attrs
|
147
|
+
def _validate_state(self, X: NDArray) -> None:
|
148
|
+
attrs = [k for c in self.__class__.mro()[:-1][::-1] if hasattr(c, "__annotations__") for k in c.__annotations__]
|
113
149
|
if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
|
114
150
|
raise RuntimeError("Metric needs to be `fit` before method call.")
|
115
151
|
self._validate(X)
|
@@ -117,7 +153,7 @@ class OODBase(ABC):
|
|
117
153
|
@abstractmethod
|
118
154
|
def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
|
119
155
|
|
120
|
-
@set_metadata
|
156
|
+
@set_metadata
|
121
157
|
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
122
158
|
"""
|
123
159
|
Compute the :term:`out of distribution<Out-of-distribution (OOD)>` scores for a given dataset.
|
@@ -140,53 +176,7 @@ class OODBase(ABC):
|
|
140
176
|
def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
|
141
177
|
return np.percentile(self._ref_score.get(ood_type), self._threshold_perc)
|
142
178
|
|
143
|
-
|
144
|
-
self,
|
145
|
-
x_ref: ArrayLike,
|
146
|
-
threshold_perc: float,
|
147
|
-
loss_fn: Callable[..., tf.Tensor],
|
148
|
-
optimizer: keras.optimizers.Optimizer,
|
149
|
-
epochs: int,
|
150
|
-
batch_size: int,
|
151
|
-
verbose: bool,
|
152
|
-
) -> None:
|
153
|
-
"""
|
154
|
-
Train the model and infer the threshold value.
|
155
|
-
|
156
|
-
Parameters
|
157
|
-
----------
|
158
|
-
x_ref : ArrayLike
|
159
|
-
Training data.
|
160
|
-
threshold_perc : float, default 100.0
|
161
|
-
Percentage of reference data that is normal.
|
162
|
-
loss_fn : Callable | None, default None
|
163
|
-
Loss function used for training.
|
164
|
-
optimizer : Optimizer, default keras.optimizers.Adam
|
165
|
-
Optimizer used for training.
|
166
|
-
epochs : int, default 20
|
167
|
-
Number of training epochs.
|
168
|
-
batch_size : int, default 64
|
169
|
-
Batch size used for training.
|
170
|
-
verbose : bool, default True
|
171
|
-
Whether to print training progress.
|
172
|
-
"""
|
173
|
-
|
174
|
-
# Train the model
|
175
|
-
trainer(
|
176
|
-
model=self.model,
|
177
|
-
loss_fn=loss_fn,
|
178
|
-
x_train=to_numpy(x_ref),
|
179
|
-
optimizer=optimizer,
|
180
|
-
epochs=epochs,
|
181
|
-
batch_size=batch_size,
|
182
|
-
verbose=verbose,
|
183
|
-
)
|
184
|
-
|
185
|
-
# Infer the threshold values
|
186
|
-
self._ref_score = self.score(x_ref, batch_size)
|
187
|
-
self._threshold_perc = threshold_perc
|
188
|
-
|
189
|
-
@set_metadata()
|
179
|
+
@set_metadata
|
190
180
|
def predict(
|
191
181
|
self,
|
192
182
|
X: ArrayLike,
|
@@ -215,43 +205,3 @@ class OODBase(ABC):
|
|
215
205
|
score = self.score(X, batch_size=batch_size)
|
216
206
|
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
217
207
|
return OODOutput(is_ood=ood_pred, **score.dict())
|
218
|
-
|
219
|
-
|
220
|
-
class OODGMMBase(OODBase):
|
221
|
-
def __init__(self, model: keras.Model) -> None:
|
222
|
-
super().__init__(model)
|
223
|
-
self.gmm_params: GaussianMixtureModelParams
|
224
|
-
|
225
|
-
def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
|
226
|
-
if additional_attrs is None:
|
227
|
-
additional_attrs = ["gmm_params"]
|
228
|
-
super()._validate_state(X, additional_attrs)
|
229
|
-
|
230
|
-
def fit(
|
231
|
-
self,
|
232
|
-
x_ref: ArrayLike,
|
233
|
-
threshold_perc: float,
|
234
|
-
loss_fn: Callable[..., tf.Tensor],
|
235
|
-
optimizer: keras.optimizers.Optimizer,
|
236
|
-
epochs: int,
|
237
|
-
batch_size: int,
|
238
|
-
verbose: bool,
|
239
|
-
) -> None:
|
240
|
-
# Train the model
|
241
|
-
trainer(
|
242
|
-
model=self.model,
|
243
|
-
loss_fn=loss_fn,
|
244
|
-
x_train=to_numpy(x_ref),
|
245
|
-
optimizer=optimizer,
|
246
|
-
epochs=epochs,
|
247
|
-
batch_size=batch_size,
|
248
|
-
verbose=verbose,
|
249
|
-
)
|
250
|
-
|
251
|
-
# Calculate the GMM parameters
|
252
|
-
_, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
|
253
|
-
self.gmm_params = gmm_params(z, gamma)
|
254
|
-
|
255
|
-
# Infer the threshold values
|
256
|
-
self._ref_score = self.score(x_ref, batch_size)
|
257
|
-
self._threshold_perc = threshold_perc
|