dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/_internal/detectors/clusterer.py +47 -34
- dataeval/_internal/detectors/drift/base.py +53 -35
- dataeval/_internal/detectors/drift/cvm.py +5 -4
- dataeval/_internal/detectors/drift/ks.py +7 -6
- dataeval/_internal/detectors/drift/mmd.py +39 -19
- dataeval/_internal/detectors/drift/torch.py +6 -5
- dataeval/_internal/detectors/drift/uncertainty.py +7 -8
- dataeval/_internal/detectors/duplicates.py +57 -30
- dataeval/_internal/detectors/linter.py +40 -24
- dataeval/_internal/detectors/ood/ae.py +2 -1
- dataeval/_internal/detectors/ood/aegmm.py +2 -1
- dataeval/_internal/detectors/ood/base.py +37 -15
- dataeval/_internal/detectors/ood/llr.py +9 -8
- dataeval/_internal/detectors/ood/vae.py +2 -1
- dataeval/_internal/detectors/ood/vaegmm.py +2 -1
- dataeval/_internal/flags.py +42 -21
- dataeval/_internal/interop.py +3 -12
- dataeval/_internal/metrics/balance.py +188 -0
- dataeval/_internal/metrics/ber.py +123 -48
- dataeval/_internal/metrics/coverage.py +90 -74
- dataeval/_internal/metrics/divergence.py +101 -67
- dataeval/_internal/metrics/diversity.py +211 -0
- dataeval/_internal/metrics/parity.py +287 -155
- dataeval/_internal/metrics/stats.py +198 -317
- dataeval/_internal/metrics/uap.py +40 -29
- dataeval/_internal/metrics/utils.py +430 -0
- dataeval/_internal/models/tensorflow/losses.py +3 -3
- dataeval/_internal/models/tensorflow/trainer.py +3 -2
- dataeval/_internal/models/tensorflow/utils.py +4 -3
- dataeval/_internal/output.py +82 -0
- dataeval/_internal/utils.py +64 -0
- dataeval/_internal/workflows/sufficiency.py +96 -107
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +26 -7
- dataeval/utils/__init__.py +9 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
- dataeval-0.65.0.dist-info/RECORD +60 -0
- dataeval/_internal/functional/__init__.py +0 -0
- dataeval/_internal/functional/ber.py +0 -63
- dataeval/_internal/functional/coverage.py +0 -75
- dataeval/_internal/functional/divergence.py +0 -16
- dataeval/_internal/functional/hash.py +0 -79
- dataeval/_internal/functional/metadata.py +0 -136
- dataeval/_internal/functional/metadataparity.py +0 -190
- dataeval/_internal/functional/uap.py +0 -6
- dataeval/_internal/functional/utils.py +0 -158
- dataeval/_internal/maite/__init__.py +0 -0
- dataeval/_internal/maite/utils.py +0 -30
- dataeval/_internal/metrics/base.py +0 -92
- dataeval/_internal/metrics/metadata.py +0 -610
- dataeval/_internal/metrics/metadataparity.py +0 -67
- dataeval-0.63.0.dist-info/RECORD +0 -68
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
@@ -1,8 +1,27 @@
|
|
1
|
-
from
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Dict, Iterable, List
|
2
3
|
|
3
|
-
from
|
4
|
-
|
5
|
-
from dataeval._internal.metrics.stats import
|
4
|
+
from numpy.typing import ArrayLike
|
5
|
+
|
6
|
+
from dataeval._internal.metrics.stats import StatsOutput
|
7
|
+
from dataeval._internal.output import OutputMetadata, set_metadata
|
8
|
+
from dataeval.flags import ImageStat
|
9
|
+
from dataeval.metrics import imagestats
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass(frozen=True)
|
13
|
+
class DuplicatesOutput(OutputMetadata):
|
14
|
+
"""
|
15
|
+
Attributes
|
16
|
+
----------
|
17
|
+
exact : List[List[int]]
|
18
|
+
Indices of images that are exact matches
|
19
|
+
near: List[List[int]]
|
20
|
+
Indices of images that are near matches
|
21
|
+
"""
|
22
|
+
|
23
|
+
exact: List[List[int]]
|
24
|
+
near: List[List[int]]
|
6
25
|
|
7
26
|
|
8
27
|
class Duplicates:
|
@@ -12,8 +31,8 @@ class Duplicates:
|
|
12
31
|
|
13
32
|
Attributes
|
14
33
|
----------
|
15
|
-
stats :
|
16
|
-
|
34
|
+
stats : StatsOutput
|
35
|
+
Output class of stats
|
17
36
|
|
18
37
|
Example
|
19
38
|
-------
|
@@ -22,25 +41,36 @@ class Duplicates:
|
|
22
41
|
>>> dups = Duplicates()
|
23
42
|
"""
|
24
43
|
|
25
|
-
def __init__(self):
|
26
|
-
self.stats
|
44
|
+
def __init__(self, find_exact: bool = True, find_near: bool = True):
|
45
|
+
self.stats: StatsOutput
|
46
|
+
self.find_exact = find_exact
|
47
|
+
self.find_near = find_near
|
27
48
|
|
28
|
-
def _get_duplicates(self) ->
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
49
|
+
def _get_duplicates(self) -> Dict[str, List[List[int]]]:
|
50
|
+
stats_dict = self.stats.dict()
|
51
|
+
if "xxhash" in stats_dict:
|
52
|
+
exact = {}
|
53
|
+
for i, value in enumerate(stats_dict["xxhash"]):
|
54
|
+
exact.setdefault(value, []).append(i)
|
55
|
+
exact = [v for v in exact.values() if len(v) > 1]
|
56
|
+
else:
|
57
|
+
exact = []
|
58
|
+
|
59
|
+
if "pchash" in stats_dict:
|
60
|
+
near = {}
|
61
|
+
for i, value in enumerate(stats_dict["pchash"]):
|
62
|
+
near.setdefault(value, []).append(i)
|
63
|
+
near = [v for v in near.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
|
64
|
+
else:
|
65
|
+
near = []
|
37
66
|
|
38
67
|
return {
|
39
68
|
"exact": sorted(exact),
|
40
69
|
"near": sorted(near),
|
41
70
|
}
|
42
71
|
|
43
|
-
|
72
|
+
@set_metadata("dataeval.detectors", ["find_exact", "find_near"])
|
73
|
+
def evaluate(self, images: Iterable[ArrayLike]) -> DuplicatesOutput:
|
44
74
|
"""
|
45
75
|
Returns duplicate image indices for both exact matches and near matches
|
46
76
|
|
@@ -51,22 +81,19 @@ class Duplicates:
|
|
51
81
|
|
52
82
|
Returns
|
53
83
|
-------
|
54
|
-
|
55
|
-
exact
|
56
|
-
List of groups of indices that are exact matches
|
57
|
-
near :
|
58
|
-
List of groups of indices that are near matches
|
84
|
+
DuplicatesOutput
|
85
|
+
List of groups of indices that are exact and near matches
|
59
86
|
|
60
87
|
See Also
|
61
88
|
--------
|
62
|
-
|
89
|
+
imagestats
|
63
90
|
|
64
91
|
Example
|
65
92
|
-------
|
66
93
|
>>> dups.evaluate(images)
|
67
|
-
|
68
|
-
"""
|
69
|
-
self.
|
70
|
-
self.
|
71
|
-
self.
|
72
|
-
return self._get_duplicates()
|
94
|
+
DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
|
95
|
+
""" # noqa: E501
|
96
|
+
flag_exact = ImageStat.XXHASH if self.find_exact else ImageStat(0)
|
97
|
+
flag_near = ImageStat.PCHASH if self.find_near else ImageStat(0)
|
98
|
+
self.stats = imagestats(images, flag_exact | flag_near)
|
99
|
+
return DuplicatesOutput(**self._get_duplicates())
|
@@ -1,15 +1,31 @@
|
|
1
|
-
from
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Dict, Iterable, Literal, Optional
|
2
3
|
|
3
4
|
import numpy as np
|
5
|
+
from numpy.typing import ArrayLike, NDArray
|
4
6
|
|
5
|
-
from dataeval._internal.flags import
|
6
|
-
from dataeval._internal.
|
7
|
-
from dataeval.
|
7
|
+
from dataeval._internal.flags import verify_supported
|
8
|
+
from dataeval._internal.output import OutputMetadata, set_metadata
|
9
|
+
from dataeval.flags import ImageStat
|
10
|
+
from dataeval.metrics import imagestats
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass(frozen=True)
|
14
|
+
class LinterOutput(OutputMetadata):
|
15
|
+
"""
|
16
|
+
Attributes
|
17
|
+
----------
|
18
|
+
issues : Dict[int, Dict[str, float]]
|
19
|
+
Dictionary containing the indices of outliers and a dictionary showing
|
20
|
+
the issues and calculated values for the given index.
|
21
|
+
"""
|
22
|
+
|
23
|
+
issues: Dict[int, Dict[str, float]]
|
8
24
|
|
9
25
|
|
10
26
|
def _get_outlier_mask(
|
11
|
-
values:
|
12
|
-
) ->
|
27
|
+
values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
|
28
|
+
) -> NDArray:
|
13
29
|
if method == "zscore":
|
14
30
|
threshold = threshold if threshold else 3.0
|
15
31
|
std = np.std(values)
|
@@ -18,7 +34,7 @@ def _get_outlier_mask(
|
|
18
34
|
elif method == "modzscore":
|
19
35
|
threshold = threshold if threshold else 3.5
|
20
36
|
abs_diff = np.abs(values - np.median(values))
|
21
|
-
med_abs_diff = np.median(abs_diff)
|
37
|
+
med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
|
22
38
|
mod_z_score = 0.6745 * abs_diff / med_abs_diff
|
23
39
|
return mod_z_score > threshold
|
24
40
|
elif method == "iqr":
|
@@ -36,8 +52,9 @@ class Linter:
|
|
36
52
|
|
37
53
|
Parameters
|
38
54
|
----------
|
39
|
-
flags :
|
55
|
+
flags : ImageStat, default ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS
|
40
56
|
Metric(s) to calculate for each image - calculates all metrics if None
|
57
|
+
Only supports ImageStat.ALL_STATS
|
41
58
|
outlier_method : ["modzscore" | "zscore" | "iqr"], optional - default "modzscore"
|
42
59
|
Statistical method used to identify outliers
|
43
60
|
outlier_threshold : float, optional - default None
|
@@ -46,8 +63,8 @@ class Linter:
|
|
46
63
|
|
47
64
|
Attributes
|
48
65
|
----------
|
49
|
-
stats :
|
50
|
-
|
66
|
+
stats : Dict[str, Any]
|
67
|
+
Dictionary to hold the value of each metric for each image
|
51
68
|
|
52
69
|
See Also
|
53
70
|
--------
|
@@ -81,7 +98,7 @@ class Linter:
|
|
81
98
|
|
82
99
|
Specifying specific metrics to analyze:
|
83
100
|
|
84
|
-
>>> lint = Linter(flags=
|
101
|
+
>>> lint = Linter(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
|
85
102
|
|
86
103
|
Specifying an outlier method:
|
87
104
|
|
@@ -94,19 +111,19 @@ class Linter:
|
|
94
111
|
|
95
112
|
def __init__(
|
96
113
|
self,
|
97
|
-
flags:
|
114
|
+
flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
|
98
115
|
outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
|
99
116
|
outlier_threshold: Optional[float] = None,
|
100
117
|
):
|
101
|
-
flags
|
102
|
-
self.
|
118
|
+
verify_supported(flags, ImageStat.ALL_STATS)
|
119
|
+
self.flags = flags
|
103
120
|
self.outlier_method: Literal["zscore", "modzscore", "iqr"] = outlier_method
|
104
121
|
self.outlier_threshold = outlier_threshold
|
105
122
|
|
106
123
|
def _get_outliers(self) -> dict:
|
107
124
|
flagged_images = {}
|
108
|
-
|
109
|
-
for stat, values in
|
125
|
+
stats_dict = self.stats.dict()
|
126
|
+
for stat, values in stats_dict.items():
|
110
127
|
if not isinstance(values, np.ndarray):
|
111
128
|
continue
|
112
129
|
|
@@ -118,7 +135,8 @@ class Linter:
|
|
118
135
|
|
119
136
|
return dict(sorted(flagged_images.items()))
|
120
137
|
|
121
|
-
|
138
|
+
@set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
|
139
|
+
def evaluate(self, images: Iterable[ArrayLike]) -> LinterOutput:
|
122
140
|
"""
|
123
141
|
Returns indices of outliers with the issues identified for each
|
124
142
|
|
@@ -130,8 +148,8 @@ class Linter:
|
|
130
148
|
|
131
149
|
Returns
|
132
150
|
-------
|
133
|
-
|
134
|
-
|
151
|
+
LinterOutput
|
152
|
+
Output class containing the indices of outliers and a dictionary showing
|
135
153
|
the issues and calculated values for the given index.
|
136
154
|
|
137
155
|
Example
|
@@ -139,9 +157,7 @@ class Linter:
|
|
139
157
|
Evaluate the dataset:
|
140
158
|
|
141
159
|
>>> lint.evaluate(images)
|
142
|
-
{18: {'brightness': 0.78}, 25: {'brightness': 0.98}}
|
160
|
+
LinterOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
|
143
161
|
"""
|
144
|
-
self.stats.
|
145
|
-
self.
|
146
|
-
self.results = self.stats.compute()
|
147
|
-
return self._get_outliers()
|
162
|
+
self.stats = imagestats(images, self.flags)
|
163
|
+
return LinterOutput(self._get_outliers())
|
@@ -10,9 +10,10 @@ from typing import Callable
|
|
10
10
|
|
11
11
|
import keras
|
12
12
|
import numpy as np
|
13
|
+
from numpy.typing import ArrayLike
|
13
14
|
|
14
15
|
from dataeval._internal.detectors.ood.base import OODBase, OODScore
|
15
|
-
from dataeval._internal.interop import
|
16
|
+
from dataeval._internal.interop import to_numpy
|
16
17
|
from dataeval._internal.models.tensorflow.autoencoder import AE
|
17
18
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
18
19
|
|
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
|
|
9
9
|
from typing import Callable
|
10
10
|
|
11
11
|
import keras
|
12
|
+
from numpy.typing import ArrayLike
|
12
13
|
|
13
14
|
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
|
14
|
-
from dataeval._internal.interop import
|
15
|
+
from dataeval._internal.interop import to_numpy
|
15
16
|
from dataeval._internal.models.tensorflow.autoencoder import AEGMM
|
16
17
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
17
18
|
from dataeval._internal.models.tensorflow.losses import LossGMM
|
@@ -7,15 +7,36 @@ Licensed under Apache Software License (Apache 2.0)
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
|
-
from
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from typing import Callable, List, Literal, NamedTuple, Optional, Tuple, cast
|
11
12
|
|
12
13
|
import keras
|
13
14
|
import numpy as np
|
14
15
|
import tensorflow as tf
|
16
|
+
from numpy.typing import ArrayLike, NDArray
|
15
17
|
|
16
|
-
from dataeval._internal.interop import
|
18
|
+
from dataeval._internal.interop import to_numpy
|
17
19
|
from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
|
18
20
|
from dataeval._internal.models.tensorflow.trainer import trainer
|
21
|
+
from dataeval._internal.output import OutputMetadata, set_metadata
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass(frozen=True)
|
25
|
+
class OODOutput(OutputMetadata):
|
26
|
+
"""
|
27
|
+
Attributes
|
28
|
+
----------
|
29
|
+
is_ood : NDArray[np.bool_]
|
30
|
+
Array of images that are detected as out of distribution
|
31
|
+
instance_score : NDArray[np.float32]
|
32
|
+
Instance score of the evaluated dataset
|
33
|
+
feature_score : Optional[NDArray[np.float32]]
|
34
|
+
Feature score, if available, of the evaluated dataset
|
35
|
+
"""
|
36
|
+
|
37
|
+
is_ood: NDArray[np.bool_]
|
38
|
+
instance_score: NDArray[np.float32]
|
39
|
+
feature_score: Optional[NDArray[np.float32]]
|
19
40
|
|
20
41
|
|
21
42
|
class OODScore(NamedTuple):
|
@@ -24,16 +45,16 @@ class OODScore(NamedTuple):
|
|
24
45
|
|
25
46
|
Parameters
|
26
47
|
----------
|
27
|
-
instance_score : np.
|
48
|
+
instance_score : NDArray[np.float32]
|
28
49
|
Instance score of the evaluated dataset.
|
29
|
-
feature_score : Optional[np.
|
50
|
+
feature_score : Optional[NDArray[np.float32]], default None
|
30
51
|
Feature score, if available, of the evaluated dataset.
|
31
52
|
"""
|
32
53
|
|
33
|
-
instance_score: np.
|
34
|
-
feature_score: Optional[np.
|
54
|
+
instance_score: NDArray[np.float32]
|
55
|
+
feature_score: Optional[NDArray[np.float32]] = None
|
35
56
|
|
36
|
-
def get(self, ood_type: Literal["instance", "feature"]) ->
|
57
|
+
def get(self, ood_type: Literal["instance", "feature"]) -> NDArray:
|
37
58
|
return self.instance_score if ood_type == "instance" or self.feature_score is None else self.feature_score
|
38
59
|
|
39
60
|
|
@@ -48,18 +69,18 @@ class OODBase(ABC):
|
|
48
69
|
if not isinstance(model, keras.Model):
|
49
70
|
raise TypeError("Model should be of type 'keras.Model'.")
|
50
71
|
|
51
|
-
def _get_data_info(self, X:
|
72
|
+
def _get_data_info(self, X: NDArray) -> Tuple[tuple, type]:
|
52
73
|
if not isinstance(X, np.ndarray):
|
53
|
-
raise TypeError("Dataset should of type: `
|
74
|
+
raise TypeError("Dataset should of type: `NDArray`.")
|
54
75
|
return X.shape[1:], X.dtype.type
|
55
76
|
|
56
|
-
def _validate(self, X:
|
77
|
+
def _validate(self, X: NDArray) -> None:
|
57
78
|
check_data_info = self._get_data_info(X)
|
58
79
|
if self._data_info is not None and check_data_info != self._data_info:
|
59
80
|
raise RuntimeError(f"Expect data of type: {self._data_info[1]} and shape: {self._data_info[0]}. \
|
60
81
|
Provided data is type: {check_data_info[1]} and shape: {check_data_info[0]}.")
|
61
82
|
|
62
|
-
def _validate_state(self, X:
|
83
|
+
def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
|
63
84
|
attrs = ["_data_info", "_threshold_perc", "_ref_score"]
|
64
85
|
attrs = attrs if additional_attrs is None else attrs + additional_attrs
|
65
86
|
if not all(hasattr(self, attr) for attr in attrs) or any(getattr(self, attr) for attr in attrs) is None:
|
@@ -131,12 +152,13 @@ class OODBase(ABC):
|
|
131
152
|
self._ref_score = self.score(x_ref, batch_size)
|
132
153
|
self._threshold_perc = threshold_perc
|
133
154
|
|
155
|
+
@set_metadata("dataeval.detectors")
|
134
156
|
def predict(
|
135
157
|
self,
|
136
158
|
X: ArrayLike,
|
137
159
|
batch_size: int = int(1e10),
|
138
160
|
ood_type: Literal["feature", "instance"] = "instance",
|
139
|
-
) ->
|
161
|
+
) -> OODOutput:
|
140
162
|
"""
|
141
163
|
Predict whether instances are out-of-distribution or not.
|
142
164
|
|
@@ -156,8 +178,8 @@ class OODBase(ABC):
|
|
156
178
|
self._validate_state(X := to_numpy(X))
|
157
179
|
# compute outlier scores
|
158
180
|
score = self.score(X, batch_size=batch_size)
|
159
|
-
ood_pred =
|
160
|
-
return
|
181
|
+
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
182
|
+
return OODOutput(is_ood=ood_pred, **score._asdict())
|
161
183
|
|
162
184
|
|
163
185
|
class OODGMMBase(OODBase):
|
@@ -165,7 +187,7 @@ class OODGMMBase(OODBase):
|
|
165
187
|
super().__init__(model)
|
166
188
|
self.gmm_params: GaussianMixtureModelParams
|
167
189
|
|
168
|
-
def _validate_state(self, X:
|
190
|
+
def _validate_state(self, X: NDArray, additional_attrs: Optional[List[str]] = None) -> None:
|
169
191
|
if additional_attrs is None:
|
170
192
|
additional_attrs = ["gmm_params"]
|
171
193
|
super()._validate_state(X, additional_attrs)
|
@@ -14,9 +14,10 @@ import numpy as np
|
|
14
14
|
import tensorflow as tf
|
15
15
|
from keras.layers import Input
|
16
16
|
from keras.models import Model
|
17
|
+
from numpy.typing import ArrayLike, NDArray
|
17
18
|
|
18
19
|
from dataeval._internal.detectors.ood.base import OODBase, OODScore
|
19
|
-
from dataeval._internal.interop import
|
20
|
+
from dataeval._internal.interop import to_numpy
|
20
21
|
from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
|
21
22
|
from dataeval._internal.models.tensorflow.trainer import trainer
|
22
23
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
@@ -51,7 +52,7 @@ def build_model(
|
|
51
52
|
|
52
53
|
|
53
54
|
def mutate_categorical(
|
54
|
-
X:
|
55
|
+
X: NDArray,
|
55
56
|
rate: float,
|
56
57
|
seed: int = 0,
|
57
58
|
feature_range: tuple = (0, 255),
|
@@ -180,7 +181,7 @@ class OOD_LLR(OODBase):
|
|
180
181
|
|
181
182
|
# create background data
|
182
183
|
mutate_fn = partial(mutate_fn, **mutate_fn_kwargs)
|
183
|
-
X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)
|
184
|
+
X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype) # type: ignore
|
184
185
|
|
185
186
|
# prepare sequential data
|
186
187
|
if self.sequential and not self.has_log_prob:
|
@@ -220,10 +221,10 @@ class OOD_LLR(OODBase):
|
|
220
221
|
def _logp(
|
221
222
|
self,
|
222
223
|
dist,
|
223
|
-
X:
|
224
|
+
X: NDArray,
|
224
225
|
return_per_feature: bool = False,
|
225
226
|
batch_size: int = int(1e10),
|
226
|
-
) ->
|
227
|
+
) -> NDArray:
|
227
228
|
"""
|
228
229
|
Compute log probability of a batch of instances under the generative model.
|
229
230
|
"""
|
@@ -234,10 +235,10 @@ class OOD_LLR(OODBase):
|
|
234
235
|
def _logp_alt(
|
235
236
|
self,
|
236
237
|
model: keras.Model,
|
237
|
-
X:
|
238
|
+
X: NDArray,
|
238
239
|
return_per_feature: bool = False,
|
239
240
|
batch_size: int = int(1e10),
|
240
|
-
) ->
|
241
|
+
) -> NDArray:
|
241
242
|
"""
|
242
243
|
Compute log probability of a batch of instances with the user defined log_prob function.
|
243
244
|
"""
|
@@ -253,7 +254,7 @@ class OOD_LLR(OODBase):
|
|
253
254
|
axis = tuple(np.arange(len(logp.shape))[1:])
|
254
255
|
return np.mean(logp, axis=axis)
|
255
256
|
|
256
|
-
def _llr(self, X:
|
257
|
+
def _llr(self, X: NDArray, return_per_feature: bool, batch_size: int = int(1e10)) -> NDArray:
|
257
258
|
"""
|
258
259
|
Compute likelihood ratios.
|
259
260
|
|
@@ -10,9 +10,10 @@ from typing import Callable
|
|
10
10
|
|
11
11
|
import keras
|
12
12
|
import numpy as np
|
13
|
+
from numpy.typing import ArrayLike
|
13
14
|
|
14
15
|
from dataeval._internal.detectors.ood.base import OODBase, OODScore
|
15
|
-
from dataeval._internal.interop import
|
16
|
+
from dataeval._internal.interop import to_numpy
|
16
17
|
from dataeval._internal.models.tensorflow.autoencoder import VAE
|
17
18
|
from dataeval._internal.models.tensorflow.losses import Elbo
|
18
19
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
@@ -10,9 +10,10 @@ from typing import Callable
|
|
10
10
|
|
11
11
|
import keras
|
12
12
|
import numpy as np
|
13
|
+
from numpy.typing import ArrayLike
|
13
14
|
|
14
15
|
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
|
15
|
-
from dataeval._internal.interop import
|
16
|
+
from dataeval._internal.interop import to_numpy
|
16
17
|
from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
|
17
18
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
18
19
|
from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
|
dataeval/_internal/flags.py
CHANGED
@@ -1,37 +1,31 @@
|
|
1
|
-
from enum import
|
2
|
-
from
|
1
|
+
from enum import IntFlag, auto
|
2
|
+
from functools import reduce
|
3
|
+
from typing import Dict, Iterable, TypeVar, Union, cast
|
3
4
|
|
5
|
+
TFlag = TypeVar("TFlag", bound=IntFlag)
|
4
6
|
|
5
|
-
class auto_all:
|
6
|
-
def __get__(self, _, cls):
|
7
|
-
return ~cls(0)
|
8
7
|
|
8
|
+
class ImageStat(IntFlag):
|
9
|
+
"""
|
10
|
+
Flags for calculating image and channel statistics
|
11
|
+
"""
|
9
12
|
|
10
|
-
|
13
|
+
# HASHES
|
11
14
|
XXHASH = auto()
|
12
15
|
PCHASH = auto()
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
class ImageProperty(Flag):
|
16
|
+
# PROPERTIES
|
17
17
|
WIDTH = auto()
|
18
18
|
HEIGHT = auto()
|
19
19
|
SIZE = auto()
|
20
20
|
ASPECT_RATIO = auto()
|
21
21
|
CHANNELS = auto()
|
22
22
|
DEPTH = auto()
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
class ImageVisuals(Flag):
|
23
|
+
# VISUALS
|
27
24
|
BRIGHTNESS = auto()
|
28
25
|
BLURRINESS = auto()
|
29
26
|
MISSING = auto()
|
30
27
|
ZERO = auto()
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
class ImageStatistics(Flag):
|
28
|
+
# PIXEL STATS
|
35
29
|
MEAN = auto()
|
36
30
|
STD = auto()
|
37
31
|
VAR = auto()
|
@@ -40,8 +34,35 @@ class ImageStatistics(Flag):
|
|
40
34
|
ENTROPY = auto()
|
41
35
|
PERCENTILES = auto()
|
42
36
|
HISTOGRAM = auto()
|
43
|
-
|
37
|
+
# JOINT FLAGS
|
38
|
+
ALL_HASHES = XXHASH | PCHASH
|
39
|
+
ALL_PROPERTIES = WIDTH | HEIGHT | SIZE | ASPECT_RATIO | CHANNELS | DEPTH
|
40
|
+
ALL_VISUALS = BRIGHTNESS | BLURRINESS | MISSING | ZERO
|
41
|
+
ALL_PIXELSTATS = MEAN | STD | VAR | SKEW | KURTOSIS | ENTROPY | PERCENTILES | HISTOGRAM
|
42
|
+
ALL_STATS = ALL_PROPERTIES | ALL_VISUALS | ALL_PIXELSTATS
|
43
|
+
ALL = ALL_HASHES | ALL_STATS
|
44
|
+
|
45
|
+
|
46
|
+
def is_distinct(flag: IntFlag) -> bool:
|
47
|
+
return (flag & (flag - 1) == 0) and flag != 0
|
48
|
+
|
49
|
+
|
50
|
+
def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
|
51
|
+
"""
|
52
|
+
Returns a distinct set of all flags set on the input flag and their names
|
53
|
+
|
54
|
+
NOTE: this is supported natively in Python 3.11, but for earlier versions we need
|
55
|
+
to use a combination of list comprehension and bit fiddling to determine distinct
|
56
|
+
flag values from joint aliases.
|
57
|
+
"""
|
58
|
+
if isinstance(flag, Iterable): # >= py311
|
59
|
+
return {f: f.name.lower() for f in flag if f.name}
|
60
|
+
else: # < py311
|
61
|
+
return {f: f.name.lower() for f in list(flag.__class__) if f & flag and is_distinct(f) and f.name}
|
44
62
|
|
45
63
|
|
46
|
-
|
47
|
-
|
64
|
+
def verify_supported(flag: TFlag, flags: Union[TFlag, Iterable[TFlag]]):
|
65
|
+
supported = flags if isinstance(flags, flag.__class__) else cast(TFlag, reduce(lambda a, b: a | b, flags)) # type: ignore
|
66
|
+
unsupported = flag & ~supported
|
67
|
+
if unsupported:
|
68
|
+
raise ValueError(f"Unsupported flags {unsupported} called. Only {supported} flags are supported.")
|
dataeval/_internal/interop.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
from importlib import import_module
|
2
|
-
from typing import
|
2
|
+
from typing import Iterable, Optional
|
3
3
|
|
4
4
|
import numpy as np
|
5
|
+
from numpy.typing import ArrayLike, NDArray
|
5
6
|
|
6
7
|
module_cache = {}
|
7
8
|
|
@@ -19,17 +20,7 @@ def try_import(module_name):
|
|
19
20
|
return module
|
20
21
|
|
21
22
|
|
22
|
-
|
23
|
-
from maite.protocols import ArrayLike # type: ignore
|
24
|
-
except ImportError: # pragma: no cover - covered by test_mindeps.py
|
25
|
-
from typing import Protocol
|
26
|
-
|
27
|
-
@runtime_checkable
|
28
|
-
class ArrayLike(Protocol):
|
29
|
-
def __array__(self) -> Any: ...
|
30
|
-
|
31
|
-
|
32
|
-
def to_numpy(array: Optional[ArrayLike]) -> np.ndarray:
|
23
|
+
def to_numpy(array: Optional[ArrayLike]) -> NDArray:
|
33
24
|
if array is None:
|
34
25
|
return np.ndarray([])
|
35
26
|
|