dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +13 -9
- dataeval/_internal/detectors/clusterer.py +63 -49
- dataeval/_internal/detectors/drift/base.py +248 -51
- dataeval/_internal/detectors/drift/cvm.py +28 -26
- dataeval/_internal/detectors/drift/ks.py +31 -28
- dataeval/_internal/detectors/drift/mmd.py +62 -42
- dataeval/_internal/detectors/drift/torch.py +69 -60
- dataeval/_internal/detectors/drift/uncertainty.py +32 -32
- dataeval/_internal/detectors/duplicates.py +67 -31
- dataeval/_internal/detectors/ood/ae.py +15 -29
- dataeval/_internal/detectors/ood/aegmm.py +33 -27
- dataeval/_internal/detectors/ood/base.py +86 -47
- dataeval/_internal/detectors/ood/llr.py +34 -31
- dataeval/_internal/detectors/ood/vae.py +32 -31
- dataeval/_internal/detectors/ood/vaegmm.py +34 -28
- dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
- dataeval/_internal/flags.py +44 -21
- dataeval/_internal/interop.py +5 -3
- dataeval/_internal/metrics/balance.py +42 -5
- dataeval/_internal/metrics/ber.py +11 -8
- dataeval/_internal/metrics/coverage.py +15 -8
- dataeval/_internal/metrics/divergence.py +41 -7
- dataeval/_internal/metrics/diversity.py +57 -19
- dataeval/_internal/metrics/parity.py +141 -66
- dataeval/_internal/metrics/stats.py +330 -313
- dataeval/_internal/metrics/uap.py +33 -4
- dataeval/_internal/metrics/utils.py +79 -40
- dataeval/_internal/models/pytorch/autoencoder.py +127 -22
- dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
- dataeval/_internal/models/tensorflow/gmm.py +4 -2
- dataeval/_internal/models/tensorflow/losses.py +17 -13
- dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
- dataeval/_internal/models/tensorflow/trainer.py +10 -7
- dataeval/_internal/models/tensorflow/utils.py +23 -20
- dataeval/_internal/output.py +85 -0
- dataeval/_internal/utils.py +5 -3
- dataeval/_internal/workflows/sufficiency.py +122 -121
- dataeval/detectors/__init__.py +6 -25
- dataeval/detectors/drift/__init__.py +16 -0
- dataeval/detectors/drift/kernels/__init__.py +6 -0
- dataeval/detectors/drift/updates/__init__.py +3 -0
- dataeval/detectors/linters/__init__.py +5 -0
- dataeval/detectors/ood/__init__.py +11 -0
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +2 -26
- dataeval/metrics/bias/__init__.py +14 -0
- dataeval/metrics/estimators/__init__.py +9 -0
- dataeval/metrics/stats/__init__.py +6 -0
- dataeval/tensorflow/__init__.py +3 -0
- dataeval/tensorflow/loss/__init__.py +3 -0
- dataeval/tensorflow/models/__init__.py +5 -0
- dataeval/tensorflow/recon/__init__.py +3 -0
- dataeval/torch/__init__.py +3 -0
- dataeval/{models/torch → torch/models}/__init__.py +1 -2
- dataeval/torch/trainer/__init__.py +3 -0
- dataeval/utils/__init__.py +3 -6
- dataeval/workflows/__init__.py +2 -4
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
- dataeval-0.66.0.dist-info/RECORD +72 -0
- dataeval/_internal/metrics/base.py +0 -10
- dataeval/models/__init__.py +0 -15
- dataeval/models/tensorflow/__init__.py +0 -6
- dataeval-0.64.0.dist-info/RECORD +0 -60
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -1,9 +1,28 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Iterable
|
2
5
|
|
3
6
|
from numpy.typing import ArrayLike
|
4
7
|
|
5
|
-
from dataeval._internal.flags import
|
6
|
-
from dataeval._internal.metrics.stats import
|
8
|
+
from dataeval._internal.flags import ImageStat
|
9
|
+
from dataeval._internal.metrics.stats import StatsOutput, imagestats
|
10
|
+
from dataeval._internal.output import OutputMetadata, set_metadata
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass(frozen=True)
|
14
|
+
class DuplicatesOutput(OutputMetadata):
|
15
|
+
"""
|
16
|
+
Attributes
|
17
|
+
----------
|
18
|
+
exact : List[List[int]]
|
19
|
+
Indices of images that are exact matches
|
20
|
+
near: List[List[int]]
|
21
|
+
Indices of images that are near matches
|
22
|
+
"""
|
23
|
+
|
24
|
+
exact: list[list[int]]
|
25
|
+
near: list[list[int]]
|
7
26
|
|
8
27
|
|
9
28
|
class Duplicates:
|
@@ -13,8 +32,13 @@ class Duplicates:
|
|
13
32
|
|
14
33
|
Attributes
|
15
34
|
----------
|
16
|
-
stats :
|
17
|
-
|
35
|
+
stats : StatsOutput
|
36
|
+
Output class of stats
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
only_exact : bool, default False
|
41
|
+
Only inspect the dataset for exact image matches
|
18
42
|
|
19
43
|
Example
|
20
44
|
-------
|
@@ -23,51 +47,63 @@ class Duplicates:
|
|
23
47
|
>>> dups = Duplicates()
|
24
48
|
"""
|
25
49
|
|
26
|
-
def __init__(self):
|
27
|
-
self.stats
|
50
|
+
def __init__(self, only_exact: bool = False):
|
51
|
+
self.stats: StatsOutput
|
52
|
+
self.only_exact = only_exact
|
53
|
+
|
54
|
+
def _get_duplicates(self) -> dict[str, list[list[int]]]:
|
55
|
+
stats_dict = self.stats.dict()
|
56
|
+
if "xxhash" in stats_dict:
|
57
|
+
exact = {}
|
58
|
+
for i, value in enumerate(stats_dict["xxhash"]):
|
59
|
+
exact.setdefault(value, []).append(i)
|
60
|
+
exact = [v for v in exact.values() if len(v) > 1]
|
61
|
+
else:
|
62
|
+
exact = []
|
28
63
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
near
|
36
|
-
exact = [v for v in exact.values() if len(v) > 1]
|
37
|
-
near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
|
64
|
+
if "pchash" in stats_dict and not self.only_exact:
|
65
|
+
near = {}
|
66
|
+
for i, value in enumerate(stats_dict["pchash"]):
|
67
|
+
near.setdefault(value, []).append(i)
|
68
|
+
near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
|
69
|
+
else:
|
70
|
+
near = []
|
38
71
|
|
39
72
|
return {
|
40
73
|
"exact": sorted(exact),
|
41
74
|
"near": sorted(near),
|
42
75
|
}
|
43
76
|
|
44
|
-
|
77
|
+
@set_metadata("dataeval.detectors", ["only_exact"])
|
78
|
+
def evaluate(self, data: Iterable[ArrayLike] | StatsOutput) -> DuplicatesOutput:
|
45
79
|
"""
|
46
80
|
Returns duplicate image indices for both exact matches and near matches
|
47
81
|
|
48
82
|
Parameters
|
49
83
|
----------
|
50
|
-
|
51
|
-
A
|
84
|
+
data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput
|
85
|
+
A dataset of images in an ArrayLike format or the output from an imagestats metric analysis
|
52
86
|
|
53
87
|
Returns
|
54
88
|
-------
|
55
|
-
|
56
|
-
exact
|
57
|
-
List of groups of indices that are exact matches
|
58
|
-
near :
|
59
|
-
List of groups of indices that are near matches
|
89
|
+
DuplicatesOutput
|
90
|
+
List of groups of indices that are exact and near matches
|
60
91
|
|
61
92
|
See Also
|
62
93
|
--------
|
63
|
-
|
94
|
+
imagestats
|
64
95
|
|
65
96
|
Example
|
66
97
|
-------
|
67
98
|
>>> dups.evaluate(images)
|
68
|
-
|
69
|
-
"""
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
99
|
+
DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
|
100
|
+
""" # noqa: E501
|
101
|
+
if isinstance(data, StatsOutput):
|
102
|
+
if not data.xxhash:
|
103
|
+
raise ValueError("StatsOutput must include xxhash information of the images.")
|
104
|
+
if not self.only_exact and not data.pchash:
|
105
|
+
raise ValueError("StatsOutput must include pchash information of the images for near matches.")
|
106
|
+
self.stats = data
|
107
|
+
else:
|
108
|
+
self.stats = imagestats(data, ImageStat.XXHASH | (ImageStat(0) if self.only_exact else ImageStat.PCHASH))
|
109
|
+
return DuplicatesOutput(**self._get_duplicates())
|
@@ -6,10 +6,13 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
|
|
6
6
|
Licensed under Apache Software License (Apache 2.0)
|
7
7
|
"""
|
8
8
|
|
9
|
+
from __future__ import annotations
|
10
|
+
|
9
11
|
from typing import Callable
|
10
12
|
|
11
13
|
import keras
|
12
14
|
import numpy as np
|
15
|
+
import tensorflow as tf
|
13
16
|
from numpy.typing import ArrayLike
|
14
17
|
|
15
18
|
from dataeval._internal.detectors.ood.base import OODBase, OODScore
|
@@ -19,47 +22,30 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
|
|
19
22
|
|
20
23
|
|
21
24
|
class OOD_AE(OODBase):
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
+
"""
|
26
|
+
Autoencoder based out-of-distribution detector.
|
27
|
+
|
28
|
+
Parameters
|
29
|
+
----------
|
30
|
+
model : AE
|
31
|
+
An Autoencoder model.
|
32
|
+
"""
|
25
33
|
|
26
|
-
|
27
|
-
----------
|
28
|
-
model : AE
|
29
|
-
An Autoencoder model.
|
30
|
-
"""
|
34
|
+
def __init__(self, model: AE) -> None:
|
31
35
|
super().__init__(model)
|
32
36
|
|
33
37
|
def fit(
|
34
38
|
self,
|
35
39
|
x_ref: ArrayLike,
|
36
40
|
threshold_perc: float = 100.0,
|
37
|
-
loss_fn: Callable =
|
41
|
+
loss_fn: Callable[..., tf.Tensor] | None = None,
|
38
42
|
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
39
43
|
epochs: int = 20,
|
40
44
|
batch_size: int = 64,
|
41
45
|
verbose: bool = True,
|
42
46
|
) -> None:
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
Parameters
|
47
|
-
----------
|
48
|
-
x_ref : ArrayLike
|
49
|
-
Training batch.
|
50
|
-
threshold_perc : float, default 100.0
|
51
|
-
Percentage of reference data that is normal.
|
52
|
-
loss_fn : Callable, default keras.losses.MeanSquaredError()
|
53
|
-
Loss function used for training.
|
54
|
-
optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
|
55
|
-
Optimizer used for training.
|
56
|
-
epochs : int, default 20
|
57
|
-
Number of training epochs.
|
58
|
-
batch_size : int, default 64
|
59
|
-
Batch size used for training.
|
60
|
-
verbose : bool, default True
|
61
|
-
Whether to print training progress.
|
62
|
-
"""
|
47
|
+
if loss_fn is None:
|
48
|
+
loss_fn = keras.losses.MeanSquaredError()
|
63
49
|
super().fit(to_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
64
50
|
|
65
51
|
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
|
@@ -6,9 +6,12 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
|
|
6
6
|
Licensed under Apache Software License (Apache 2.0)
|
7
7
|
"""
|
8
8
|
|
9
|
+
from __future__ import annotations
|
10
|
+
|
9
11
|
from typing import Callable
|
10
12
|
|
11
13
|
import keras
|
14
|
+
import tensorflow as tf
|
12
15
|
from numpy.typing import ArrayLike
|
13
16
|
|
14
17
|
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
|
@@ -20,50 +23,53 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
|
|
20
23
|
|
21
24
|
|
22
25
|
class OOD_AEGMM(OODGMMBase):
|
23
|
-
|
24
|
-
|
25
|
-
AE with Gaussian Mixture Model based outlier detector.
|
26
|
+
"""
|
27
|
+
AE with Gaussian Mixture Model based outlier detector.
|
26
28
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
29
|
+
Parameters
|
30
|
+
----------
|
31
|
+
model : AEGMM
|
32
|
+
An AEGMM model.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def __init__(self, model: AEGMM) -> None:
|
32
36
|
super().__init__(model)
|
33
37
|
|
34
38
|
def fit(
|
35
39
|
self,
|
36
40
|
x_ref: ArrayLike,
|
37
41
|
threshold_perc: float = 100.0,
|
38
|
-
loss_fn: Callable =
|
42
|
+
loss_fn: Callable[..., tf.Tensor] | None = None,
|
39
43
|
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
40
44
|
epochs: int = 20,
|
41
45
|
batch_size: int = 64,
|
42
46
|
verbose: bool = True,
|
43
47
|
) -> None:
|
48
|
+
if loss_fn is None:
|
49
|
+
loss_fn = LossGMM()
|
50
|
+
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
51
|
+
|
52
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
|
44
53
|
"""
|
45
|
-
|
54
|
+
Compute the out-of-distribution (OOD) score for a given dataset.
|
46
55
|
|
47
56
|
Parameters
|
48
57
|
----------
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
Loss function used for training.
|
55
|
-
optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
|
56
|
-
Optimizer used for training.
|
57
|
-
epochs : int, default 20
|
58
|
-
Number of training epochs.
|
59
|
-
batch_size : int, default 64
|
60
|
-
Batch size used for training.
|
61
|
-
verbose : bool, default True
|
62
|
-
Whether to print training progress.
|
63
|
-
"""
|
64
|
-
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
58
|
+
X : ArrayLike
|
59
|
+
Input data to score.
|
60
|
+
batch_size : int, default 1e10
|
61
|
+
Number of instances to process in each batch.
|
62
|
+
Use a smaller batch size if your dataset is large or if you encounter memory issues.
|
65
63
|
|
66
|
-
|
64
|
+
Returns
|
65
|
+
-------
|
66
|
+
OODScore
|
67
|
+
An object containing the instance-level OOD score.
|
68
|
+
|
69
|
+
Note
|
70
|
+
----
|
71
|
+
This model does not produce a feature level score like the OOD_AE or OOD_VAE models.
|
72
|
+
"""
|
67
73
|
self._validate(X := to_numpy(X))
|
68
74
|
_, z, _ = predict_batch(X, self.model, batch_size=batch_size)
|
69
75
|
energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
|
@@ -6,17 +6,39 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
|
|
6
6
|
Licensed under Apache Software License (Apache 2.0)
|
7
7
|
"""
|
8
8
|
|
9
|
+
from __future__ import annotations
|
10
|
+
|
9
11
|
from abc import ABC, abstractmethod
|
10
|
-
from
|
12
|
+
from dataclasses import dataclass
|
13
|
+
from typing import Callable, Literal, NamedTuple, cast
|
11
14
|
|
12
15
|
import keras
|
13
16
|
import numpy as np
|
14
17
|
import tensorflow as tf
|
15
|
-
from numpy.typing import ArrayLike
|
18
|
+
from numpy.typing import ArrayLike, NDArray
|
16
19
|
|
17
20
|
from dataeval._internal.interop import to_numpy
|
18
21
|
from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
|
19
22
|
from dataeval._internal.models.tensorflow.trainer import trainer
|
23
|
+
from dataeval._internal.output import OutputMetadata, set_metadata
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass(frozen=True)
|
27
|
+
class OODOutput(OutputMetadata):
|
28
|
+
"""
|
29
|
+
Attributes
|
30
|
+
----------
|
31
|
+
is_ood : NDArray
|
32
|
+
Array of images that are detected as out of distribution
|
33
|
+
instance_score : NDArray
|
34
|
+
Instance score of the evaluated dataset
|
35
|
+
feature_score : NDArray | None
|
36
|
+
Feature score, if available, of the evaluated dataset
|
37
|
+
"""
|
38
|
+
|
39
|
+
is_ood: NDArray[np.bool_]
|
40
|
+
instance_score: NDArray[np.float32]
|
41
|
+
feature_score: NDArray[np.float32] | None
|
20
42
|
|
21
43
|
|
22
44
|
class OODScore(NamedTuple):
|
@@ -25,16 +47,28 @@ class OODScore(NamedTuple):
|
|
25
47
|
|
26
48
|
Parameters
|
27
49
|
----------
|
28
|
-
instance_score :
|
50
|
+
instance_score : NDArray
|
29
51
|
Instance score of the evaluated dataset.
|
30
|
-
feature_score :
|
52
|
+
feature_score : NDArray | None, default None
|
31
53
|
Feature score, if available, of the evaluated dataset.
|
32
54
|
"""
|
33
55
|
|
34
|
-
instance_score: np.
|
35
|
-
feature_score:
|
56
|
+
instance_score: NDArray[np.float32]
|
57
|
+
feature_score: NDArray[np.float32] | None = None
|
58
|
+
|
59
|
+
def get(self, ood_type: Literal["instance", "feature"]) -> NDArray:
|
60
|
+
"""
|
61
|
+
Returns either the instance or feature score
|
62
|
+
|
63
|
+
Parameters
|
64
|
+
----------
|
65
|
+
ood_type : "instance" | "feature"
|
36
66
|
|
37
|
-
|
67
|
+
Returns
|
68
|
+
-------
|
69
|
+
NDArray
|
70
|
+
Either the instance or feature score based on input selection
|
71
|
+
"""
|
38
72
|
return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
|
39
73
|
|
40
74
|
|
@@ -44,23 +78,23 @@ class OODBase(ABC):
|
|
44
78
|
|
45
79
|
self._ref_score: OODScore
|
46
80
|
self._threshold_perc: float
|
47
|
-
self._data_info:
|
81
|
+
self._data_info: tuple[tuple, type] | None = None
|
48
82
|
|
49
83
|
if not isinstance(model, keras.Model):
|
50
84
|
raise TypeError("Model should be of type 'keras.Model'.")
|
51
85
|
|
52
|
-
def _get_data_info(self, X:
|
86
|
+
def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
|
53
87
|
if not isinstance(X, np.ndarray):
|
54
|
-
raise TypeError("Dataset should of type: `
|
88
|
+
raise TypeError("Dataset should of type: `NDArray`.")
|
55
89
|
return X.shape[1:], X.dtype.type
|
56
90
|
|
57
|
-
def _validate(self, X:
|
91
|
+
def _validate(self, X: NDArray) -> None:
|
58
92
|
check_data_info = self._get_data_info(X)
|
59
93
|
if self._data_info is not None and check_data_info != self._data_info:
|
60
94
|
raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
|
61
95
|
Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
|
62
96
|
|
63
|
-
def _validate_state(self, X:
|
97
|
+
def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
|
64
98
|
attrs = ["_data_info", "_threshold_perc", "_ref_score"]
|
65
99
|
attrs = attrs if additional_attrs is None else attrs + additional_attrs
|
66
100
|
if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
|
@@ -70,18 +104,20 @@ class OODBase(ABC):
|
|
70
104
|
@abstractmethod
|
71
105
|
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
|
72
106
|
"""
|
73
|
-
Compute
|
107
|
+
Compute the out-of-distribution (OOD) scores for a given dataset.
|
74
108
|
|
75
109
|
Parameters
|
76
110
|
----------
|
77
111
|
X : ArrayLike
|
78
|
-
|
79
|
-
batch_size : int, default
|
80
|
-
|
112
|
+
Input data to score.
|
113
|
+
batch_size : int, default 1e10
|
114
|
+
Number of instances to process in each batch.
|
115
|
+
Use a smaller batch size if your dataset is large or if you encounter memory issues.
|
81
116
|
|
82
117
|
Returns
|
83
118
|
-------
|
84
|
-
|
119
|
+
OODScore
|
120
|
+
An object containing the instance-level and feature-level OOD scores.
|
85
121
|
"""
|
86
122
|
|
87
123
|
def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
|
@@ -90,33 +126,34 @@ class OODBase(ABC):
|
|
90
126
|
def fit(
|
91
127
|
self,
|
92
128
|
x_ref: ArrayLike,
|
93
|
-
threshold_perc: float,
|
94
|
-
loss_fn: Callable,
|
95
|
-
optimizer: keras.optimizers.Optimizer,
|
96
|
-
epochs: int,
|
97
|
-
batch_size: int,
|
98
|
-
verbose: bool,
|
129
|
+
threshold_perc: float = 100.0,
|
130
|
+
loss_fn: Callable[..., tf.Tensor] | None = None,
|
131
|
+
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
132
|
+
epochs: int = 20,
|
133
|
+
batch_size: int = 64,
|
134
|
+
verbose: bool = True,
|
99
135
|
) -> None:
|
100
136
|
"""
|
101
137
|
Train the model and infer the threshold value.
|
102
138
|
|
103
139
|
Parameters
|
104
140
|
----------
|
105
|
-
x_ref
|
106
|
-
Training
|
107
|
-
threshold_perc : float
|
141
|
+
x_ref : ArrayLike
|
142
|
+
Training data.
|
143
|
+
threshold_perc : float, default 100.0
|
108
144
|
Percentage of reference data that is normal.
|
109
|
-
loss_fn : Callable
|
145
|
+
loss_fn : Callable | None, default None
|
110
146
|
Loss function used for training.
|
111
|
-
optimizer : keras.optimizers.
|
147
|
+
optimizer : Optimizer, default keras.optimizers.Adam
|
112
148
|
Optimizer used for training.
|
113
|
-
epochs : int
|
149
|
+
epochs : int, default 20
|
114
150
|
Number of training epochs.
|
115
|
-
batch_size : int
|
151
|
+
batch_size : int, default 64
|
116
152
|
Batch size used for training.
|
117
|
-
verbose : bool
|
153
|
+
verbose : bool, default True
|
118
154
|
Whether to print training progress.
|
119
155
|
"""
|
156
|
+
|
120
157
|
# Train the model
|
121
158
|
trainer(
|
122
159
|
model=self.model,
|
@@ -132,33 +169,35 @@ class OODBase(ABC):
|
|
132
169
|
self._ref_score = self.score(x_ref, batch_size)
|
133
170
|
self._threshold_perc = threshold_perc
|
134
171
|
|
172
|
+
@set_metadata("dataeval.detectors")
|
135
173
|
def predict(
|
136
174
|
self,
|
137
175
|
X: ArrayLike,
|
138
176
|
batch_size: int = int(1e10),
|
139
177
|
ood_type: Literal["feature", "instance"] = "instance",
|
140
|
-
) ->
|
178
|
+
) -> OODOutput:
|
141
179
|
"""
|
142
180
|
Predict whether instances are out-of-distribution or not.
|
143
181
|
|
144
182
|
Parameters
|
145
183
|
----------
|
146
184
|
X : ArrayLike
|
147
|
-
|
148
|
-
batch_size : int, default
|
149
|
-
|
150
|
-
ood_type :
|
185
|
+
Input data for out-of-distribution prediction.
|
186
|
+
batch_size : int, default 1e10
|
187
|
+
Number of instances to process in each batch.
|
188
|
+
ood_type : "feature" | "instance", default "instance"
|
151
189
|
Predict out-of-distribution at the 'feature' or 'instance' level.
|
152
190
|
|
153
191
|
Returns
|
154
192
|
-------
|
155
|
-
Dictionary containing the outlier predictions
|
193
|
+
Dictionary containing the outlier predictions for the selected level,
|
194
|
+
and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
|
156
195
|
"""
|
157
196
|
self._validate_state(X := to_numpy(X))
|
158
197
|
# compute outlier scores
|
159
198
|
score = self.score(X, batch_size=batch_size)
|
160
|
-
ood_pred =
|
161
|
-
return
|
199
|
+
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
200
|
+
return OODOutput(is_ood=ood_pred, **score._asdict())
|
162
201
|
|
163
202
|
|
164
203
|
class OODGMMBase(OODBase):
|
@@ -166,7 +205,7 @@ class OODGMMBase(OODBase):
|
|
166
205
|
super().__init__(model)
|
167
206
|
self.gmm_params: GaussianMixtureModelParams
|
168
207
|
|
169
|
-
def _validate_state(self, X:
|
208
|
+
def _validate_state(self, X: NDArray, additional_attrs: list[str] | None = None) -> None:
|
170
209
|
if additional_attrs is None:
|
171
210
|
additional_attrs = ["gmm_params"]
|
172
211
|
super()._validate_state(X, additional_attrs)
|
@@ -174,12 +213,12 @@ class OODGMMBase(OODBase):
|
|
174
213
|
def fit(
|
175
214
|
self,
|
176
215
|
x_ref: ArrayLike,
|
177
|
-
threshold_perc: float,
|
178
|
-
loss_fn: Callable[
|
179
|
-
optimizer: keras.optimizers.Optimizer,
|
180
|
-
epochs: int,
|
181
|
-
batch_size: int,
|
182
|
-
verbose: bool,
|
216
|
+
threshold_perc: float = 100.0,
|
217
|
+
loss_fn: Callable[..., tf.Tensor] | None = None,
|
218
|
+
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
219
|
+
epochs: int = 20,
|
220
|
+
batch_size: int = 64,
|
221
|
+
verbose: bool = True,
|
183
222
|
) -> None:
|
184
223
|
# Train the model
|
185
224
|
trainer(
|
@@ -193,7 +232,7 @@ class OODGMMBase(OODBase):
|
|
193
232
|
)
|
194
233
|
|
195
234
|
# Calculate the GMM parameters
|
196
|
-
_, z, gamma = cast(
|
235
|
+
_, z, gamma = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.model(x_ref))
|
197
236
|
self.gmm_params = gmm_params(z, gamma)
|
198
237
|
|
199
238
|
# Infer the threshold values
|