dataeval-0.69.4-py3-none-any.whl → dataeval-0.70.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +8 -8
- dataeval/_internal/datasets.py +235 -131
- dataeval/_internal/detectors/clusterer.py +2 -0
- dataeval/_internal/detectors/drift/base.py +7 -8
- dataeval/_internal/detectors/drift/mmd.py +4 -4
- dataeval/_internal/detectors/duplicates.py +64 -45
- dataeval/_internal/detectors/merged_stats.py +23 -54
- dataeval/_internal/detectors/ood/ae.py +8 -6
- dataeval/_internal/detectors/ood/aegmm.py +6 -4
- dataeval/_internal/detectors/ood/base.py +12 -7
- dataeval/_internal/detectors/ood/llr.py +6 -4
- dataeval/_internal/detectors/ood/vae.py +5 -3
- dataeval/_internal/detectors/ood/vaegmm.py +6 -4
- dataeval/_internal/detectors/outliers.py +137 -63
- dataeval/_internal/interop.py +11 -7
- dataeval/_internal/metrics/balance.py +13 -11
- dataeval/_internal/metrics/ber.py +5 -3
- dataeval/_internal/metrics/coverage.py +4 -0
- dataeval/_internal/metrics/divergence.py +9 -5
- dataeval/_internal/metrics/diversity.py +14 -12
- dataeval/_internal/metrics/parity.py +32 -22
- dataeval/_internal/metrics/stats/base.py +231 -0
- dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
- dataeval/_internal/metrics/stats/datasetstats.py +99 -0
- dataeval/_internal/metrics/stats/dimensionstats.py +113 -0
- dataeval/_internal/metrics/stats/hashstats.py +75 -0
- dataeval/_internal/metrics/stats/labelstats.py +125 -0
- dataeval/_internal/metrics/stats/pixelstats.py +119 -0
- dataeval/_internal/metrics/stats/visualstats.py +124 -0
- dataeval/_internal/metrics/uap.py +8 -4
- dataeval/_internal/metrics/utils.py +30 -15
- dataeval/_internal/models/pytorch/autoencoder.py +5 -5
- dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
- dataeval/_internal/output.py +3 -18
- dataeval/_internal/utils.py +11 -16
- dataeval/_internal/workflows/sufficiency.py +152 -151
- dataeval/detectors/__init__.py +4 -0
- dataeval/detectors/drift/__init__.py +8 -3
- dataeval/detectors/drift/kernels/__init__.py +4 -0
- dataeval/detectors/drift/updates/__init__.py +4 -0
- dataeval/detectors/linters/__init__.py +15 -4
- dataeval/detectors/ood/__init__.py +14 -2
- dataeval/metrics/__init__.py +5 -0
- dataeval/metrics/bias/__init__.py +13 -4
- dataeval/metrics/estimators/__init__.py +8 -8
- dataeval/metrics/stats/__init__.py +25 -3
- dataeval/utils/__init__.py +16 -3
- dataeval/utils/tensorflow/__init__.py +11 -0
- dataeval/utils/torch/__init__.py +12 -0
- dataeval/utils/torch/datasets/__init__.py +7 -0
- dataeval/workflows/__init__.py +6 -2
- {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/METADATA +12 -4
- dataeval-0.70.1.dist-info/RECORD +80 -0
- {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/WHEEL +1 -1
- dataeval/_internal/flags.py +0 -77
- dataeval/_internal/metrics/stats.py +0 -397
- dataeval/flags/__init__.py +0 -3
- dataeval/tensorflow/__init__.py +0 -3
- dataeval/torch/__init__.py +0 -3
- dataeval-0.69.4.dist-info/RECORD +0 -74
- /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
- {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/LICENSE.txt +0 -0
@@ -1,13 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from dataclasses import dataclass
|
4
|
-
from typing import Generic, Iterable, Sequence, TypeVar
|
4
|
+
from typing import Generic, Iterable, Sequence, TypeVar
|
5
5
|
|
6
6
|
from numpy.typing import ArrayLike
|
7
7
|
|
8
8
|
from dataeval._internal.detectors.merged_stats import combine_stats, get_dataset_step_from_idx
|
9
|
-
from dataeval._internal.
|
10
|
-
from dataeval._internal.metrics.stats import StatsOutput, imagestats
|
9
|
+
from dataeval._internal.metrics.stats.hashstats import HashStatsOutput, hashstats
|
11
10
|
from dataeval._internal.output import OutputMetadata, set_metadata
|
12
11
|
|
13
12
|
DuplicateGroup = list[int]
|
@@ -18,6 +17,8 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
|
|
18
17
|
@dataclass(frozen=True)
|
19
18
|
class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
|
20
19
|
"""
|
20
|
+
Output class for :class:`Duplicates` lint detector
|
21
|
+
|
21
22
|
Attributes
|
22
23
|
----------
|
23
24
|
exact : list[list[int] | dict[int, list[int]]]
|
@@ -53,26 +54,23 @@ class Duplicates:
|
|
53
54
|
-------
|
54
55
|
Initialize the Duplicates class:
|
55
56
|
|
56
|
-
>>>
|
57
|
+
>>> all_dupes = Duplicates()
|
58
|
+
>>> exact_dupes = Duplicates(only_exact=True)
|
57
59
|
"""
|
58
60
|
|
59
61
|
def __init__(self, only_exact: bool = False):
|
60
|
-
self.stats:
|
62
|
+
self.stats: HashStatsOutput
|
61
63
|
self.only_exact = only_exact
|
62
64
|
|
63
|
-
def _get_duplicates(self) -> dict[str, list[list[int]]]:
|
64
|
-
|
65
|
-
|
66
|
-
exact_dict
|
67
|
-
|
68
|
-
exact_dict.setdefault(value, []).append(i)
|
69
|
-
exact = [sorted(v) for v in exact_dict.values() if len(v) > 1]
|
70
|
-
else:
|
71
|
-
exact = []
|
65
|
+
def _get_duplicates(self, stats: dict) -> dict[str, list[list[int]]]:
|
66
|
+
exact_dict: dict[int, list] = {}
|
67
|
+
for i, value in enumerate(stats["xxhash"]):
|
68
|
+
exact_dict.setdefault(value, []).append(i)
|
69
|
+
exact = [sorted(v) for v in exact_dict.values() if len(v) > 1]
|
72
70
|
|
73
|
-
if
|
71
|
+
if not self.only_exact:
|
74
72
|
near_dict: dict[int, list] = {}
|
75
|
-
for i, value in enumerate(
|
73
|
+
for i, value in enumerate(stats["pchash"]):
|
76
74
|
near_dict.setdefault(value, []).append(i)
|
77
75
|
near = [sorted(v) for v in near_dict.values() if len(v) > 1 and not any(set(v).issubset(x) for x in exact)]
|
78
76
|
else:
|
@@ -84,14 +82,14 @@ class Duplicates:
|
|
84
82
|
}
|
85
83
|
|
86
84
|
@set_metadata("dataeval.detectors", ["only_exact"])
|
87
|
-
def
|
85
|
+
def from_stats(self, hashes: HashStatsOutput | Sequence[HashStatsOutput]) -> DuplicatesOutput:
|
88
86
|
"""
|
89
87
|
Returns duplicate image indices for both exact matches and near matches
|
90
88
|
|
91
89
|
Parameters
|
92
90
|
----------
|
93
|
-
data :
|
94
|
-
|
91
|
+
data : HashStatsOutput | Sequence[HashStatsOutput]
|
92
|
+
The output(s) from a hashstats analysis
|
95
93
|
|
96
94
|
Returns
|
97
95
|
-------
|
@@ -100,39 +98,60 @@ class Duplicates:
|
|
100
98
|
|
101
99
|
See Also
|
102
100
|
--------
|
103
|
-
|
101
|
+
hashstats
|
104
102
|
|
105
103
|
Example
|
106
104
|
-------
|
107
|
-
>>>
|
108
|
-
DuplicatesOutput(exact=[[3, 20], [16
|
109
|
-
"""
|
105
|
+
>>> exact_dupes.from_stats([hashes1, hashes2])
|
106
|
+
DuplicatesOutput(exact=[{0: [3, 20]}, {0: [16], 1: [12]}], near=[])
|
107
|
+
"""
|
110
108
|
|
111
|
-
|
109
|
+
if isinstance(hashes, HashStatsOutput):
|
110
|
+
return DuplicatesOutput(**self._get_duplicates(hashes.dict()))
|
112
111
|
|
113
|
-
if isinstance(
|
114
|
-
|
115
|
-
raise ValueError("StatsOutput must include xxhash information of the images.")
|
116
|
-
if not self.only_exact and not stats.pchash:
|
117
|
-
raise ValueError("StatsOutput must include pchash information of the images for near matches.")
|
118
|
-
self.stats = stats
|
119
|
-
else:
|
120
|
-
flags = ImageStat.XXHASH | (ImageStat(0) if self.only_exact else ImageStat.PCHASH)
|
121
|
-
self.stats = imagestats(cast(Iterable[ArrayLike], data), flags)
|
112
|
+
if not isinstance(hashes, Sequence):
|
113
|
+
raise TypeError("Invalid stats output type; only use output from hashstats.")
|
122
114
|
|
123
|
-
|
115
|
+
combined, dataset_steps = combine_stats(hashes)
|
116
|
+
duplicates = self._get_duplicates(combined.dict())
|
124
117
|
|
125
118
|
# split up results from combined dataset into individual dataset buckets
|
126
|
-
|
127
|
-
|
128
|
-
for
|
129
|
-
|
130
|
-
for
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
dup_list_dict.append(dup_dict)
|
136
|
-
duplicates[dup_type] = dup_list_dict
|
119
|
+
for dup_type, dup_list in duplicates.items():
|
120
|
+
dup_list_dict = []
|
121
|
+
for idxs in dup_list:
|
122
|
+
dup_dict = {}
|
123
|
+
for idx in idxs:
|
124
|
+
k, v = get_dataset_step_from_idx(idx, dataset_steps)
|
125
|
+
dup_dict.setdefault(k, []).append(v)
|
126
|
+
dup_list_dict.append(dup_dict)
|
127
|
+
duplicates[dup_type] = dup_list_dict
|
137
128
|
|
138
129
|
return DuplicatesOutput(**duplicates)
|
130
|
+
|
131
|
+
@set_metadata("dataeval.detectors", ["only_exact"])
|
132
|
+
def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput:
|
133
|
+
"""
|
134
|
+
Returns duplicate image indices for both exact matches and near matches
|
135
|
+
|
136
|
+
Parameters
|
137
|
+
----------
|
138
|
+
data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput | Sequence[StatsOutput]
|
139
|
+
A dataset of images in an ArrayLike format or the output(s) from a hashstats analysis
|
140
|
+
|
141
|
+
Returns
|
142
|
+
-------
|
143
|
+
DuplicatesOutput
|
144
|
+
List of groups of indices that are exact and near matches
|
145
|
+
|
146
|
+
See Also
|
147
|
+
--------
|
148
|
+
hashstats
|
149
|
+
|
150
|
+
Example
|
151
|
+
-------
|
152
|
+
>>> all_dupes.evaluate(images)
|
153
|
+
DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
|
154
|
+
""" # noqa: E501
|
155
|
+
self.stats = hashstats(data)
|
156
|
+
duplicates = self._get_duplicates(self.stats.dict())
|
157
|
+
return DuplicatesOutput(**duplicates)
|
@@ -1,71 +1,40 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
3
|
+
from copy import deepcopy
|
4
|
+
from typing import Sequence, TypeVar
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
|
8
|
-
from dataeval._internal.metrics.stats import
|
9
|
-
from dataeval._internal.output import populate_defaults
|
8
|
+
from dataeval._internal.metrics.stats.base import BaseStatsOutput
|
10
9
|
|
10
|
+
TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput)
|
11
11
|
|
12
|
-
def add_stats(a: StatsOutput, b: StatsOutput) -> StatsOutput:
|
13
|
-
if not isinstance(a, StatsOutput) or not isinstance(b, StatsOutput):
|
14
|
-
raise TypeError(f"Cannot add object of type {type(a)} and type {type(b)}.")
|
15
12
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
b_keys = set(b_dict)
|
13
|
+
def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
|
14
|
+
if type(a) is not type(b):
|
15
|
+
raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")
|
20
16
|
|
21
|
-
|
22
|
-
if missing_keys:
|
23
|
-
raise ValueError(f"Required keys are missing: {missing_keys}.")
|
17
|
+
sum_dict = deepcopy(a.dict())
|
24
18
|
|
25
|
-
|
26
|
-
|
27
|
-
|
19
|
+
for k in sum_dict:
|
20
|
+
if isinstance(sum_dict[k], list):
|
21
|
+
sum_dict[k].extend(b.dict()[k])
|
22
|
+
else:
|
23
|
+
sum_dict[k] = np.concatenate((sum_dict[k], b.dict()[k]))
|
28
24
|
|
29
|
-
|
30
|
-
if "ch_idx_map" in a_dict:
|
31
|
-
for k, v in a_dict.items():
|
32
|
-
if k == "ch_idx_map":
|
33
|
-
offset = sum([len(idxs) for idxs in v.values()])
|
34
|
-
for ch_k, ch_v in b_dict[k].items():
|
35
|
-
if ch_k not in v:
|
36
|
-
v[ch_k] = []
|
37
|
-
a_dict[k][ch_k].extend([idx + offset for idx in ch_v])
|
38
|
-
else:
|
39
|
-
for ch_k in b_dict[k]:
|
40
|
-
if ch_k not in v:
|
41
|
-
v[ch_k] = b_dict[k][ch_k]
|
42
|
-
else:
|
43
|
-
v[ch_k] = np.concatenate((v[ch_k], b_dict[k][ch_k]), axis=1)
|
44
|
-
else:
|
45
|
-
for k in a_dict:
|
46
|
-
if isinstance(a_dict[k], list):
|
47
|
-
a_dict[k].extend(b_dict[k])
|
48
|
-
else:
|
49
|
-
a_dict[k] = np.concatenate((a_dict[k], b_dict[k]))
|
25
|
+
return type(a)(**sum_dict)
|
50
26
|
|
51
|
-
return StatsOutput(**populate_defaults(a_dict, StatsOutput))
|
52
|
-
|
53
|
-
|
54
|
-
def combine_stats(stats) -> tuple[StatsOutput | None, list[int]]:
|
55
|
-
dataset_steps = []
|
56
|
-
|
57
|
-
if isinstance(stats, StatsOutput):
|
58
|
-
return stats, dataset_steps
|
59
27
|
|
28
|
+
def combine_stats(stats: Sequence[TStatsOutput]) -> tuple[TStatsOutput, list[int]]:
|
60
29
|
output = None
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
30
|
+
dataset_steps = []
|
31
|
+
cur_len = 0
|
32
|
+
for s in stats:
|
33
|
+
output = s if output is None else add_stats(output, s)
|
34
|
+
cur_len += len(s)
|
35
|
+
dataset_steps.append(cur_len)
|
36
|
+
if output is None:
|
37
|
+
raise TypeError("Cannot combine empty sequence of stats.")
|
69
38
|
return output, dataset_steps
|
70
39
|
|
71
40
|
|
@@ -15,10 +15,11 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
19
|
-
from dataeval._internal.interop import
|
18
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
19
|
+
from dataeval._internal.interop import as_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import AE
|
21
21
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
22
|
+
from dataeval._internal.output import set_metadata
|
22
23
|
|
23
24
|
|
24
25
|
class OOD_AE(OODBase):
|
@@ -46,10 +47,11 @@ class OOD_AE(OODBase):
|
|
46
47
|
) -> None:
|
47
48
|
if loss_fn is None:
|
48
49
|
loss_fn = keras.losses.MeanSquaredError()
|
49
|
-
super().fit(
|
50
|
+
super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
50
51
|
|
51
|
-
|
52
|
-
|
52
|
+
@set_metadata("dataeval.detectors")
|
53
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
54
|
+
self._validate(X := as_numpy(X))
|
53
55
|
|
54
56
|
# reconstruct instances
|
55
57
|
X_recon = predict_batch(X, self.model, batch_size=batch_size)
|
@@ -62,4 +64,4 @@ class OOD_AE(OODBase):
|
|
62
64
|
sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
|
63
65
|
iscore = np.mean(sorted_fscore_perc, axis=1)
|
64
66
|
|
65
|
-
return
|
67
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -14,12 +14,13 @@ import keras
|
|
14
14
|
import tensorflow as tf
|
15
15
|
from numpy.typing import ArrayLike
|
16
16
|
|
17
|
-
from dataeval._internal.detectors.ood.base import OODGMMBase,
|
17
|
+
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
|
18
18
|
from dataeval._internal.interop import to_numpy
|
19
19
|
from dataeval._internal.models.tensorflow.autoencoder import AEGMM
|
20
20
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
21
21
|
from dataeval._internal.models.tensorflow.losses import LossGMM
|
22
22
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
23
|
+
from dataeval._internal.output import set_metadata
|
23
24
|
|
24
25
|
|
25
26
|
class OOD_AEGMM(OODGMMBase):
|
@@ -49,7 +50,8 @@ class OOD_AEGMM(OODGMMBase):
|
|
49
50
|
loss_fn = LossGMM()
|
50
51
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
51
52
|
|
52
|
-
|
53
|
+
@set_metadata("dataeval.detectors")
|
54
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
53
55
|
"""
|
54
56
|
Compute the out-of-distribution (OOD) score for a given dataset.
|
55
57
|
|
@@ -63,7 +65,7 @@ class OOD_AEGMM(OODGMMBase):
|
|
63
65
|
|
64
66
|
Returns
|
65
67
|
-------
|
66
|
-
|
68
|
+
OODScoreOutput
|
67
69
|
An object containing the instance-level OOD score.
|
68
70
|
|
69
71
|
Note
|
@@ -73,4 +75,4 @@ class OOD_AEGMM(OODGMMBase):
|
|
73
75
|
self._validate(X := to_numpy(X))
|
74
76
|
_, z, _ = predict_batch(X, self.model, batch_size=batch_size)
|
75
77
|
energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
|
76
|
-
return
|
78
|
+
return OODScoreOutput(energy.numpy()) # type: ignore
|
@@ -10,7 +10,7 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
from abc import ABC, abstractmethod
|
12
12
|
from dataclasses import dataclass
|
13
|
-
from typing import Callable, Literal,
|
13
|
+
from typing import Callable, Literal, cast
|
14
14
|
|
15
15
|
import keras
|
16
16
|
import numpy as np
|
@@ -26,6 +26,9 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
26
26
|
@dataclass(frozen=True)
|
27
27
|
class OODOutput(OutputMetadata):
|
28
28
|
"""
|
29
|
+
Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
|
30
|
+
:class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
31
|
+
|
29
32
|
Attributes
|
30
33
|
----------
|
31
34
|
is_ood : NDArray
|
@@ -41,9 +44,11 @@ class OODOutput(OutputMetadata):
|
|
41
44
|
feature_score: NDArray[np.float32] | None
|
42
45
|
|
43
46
|
|
44
|
-
|
47
|
+
@dataclass(frozen=True)
|
48
|
+
class OODScoreOutput(OutputMetadata):
|
45
49
|
"""
|
46
|
-
|
50
|
+
Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
|
51
|
+
:class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
47
52
|
|
48
53
|
Parameters
|
49
54
|
----------
|
@@ -76,7 +81,7 @@ class OODBase(ABC):
|
|
76
81
|
def __init__(self, model: keras.Model) -> None:
|
77
82
|
self.model = model
|
78
83
|
|
79
|
-
self._ref_score:
|
84
|
+
self._ref_score: OODScoreOutput
|
80
85
|
self._threshold_perc: float
|
81
86
|
self._data_info: tuple[tuple, type] | None = None
|
82
87
|
|
@@ -102,7 +107,7 @@ class OODBase(ABC):
|
|
102
107
|
self._validate(X)
|
103
108
|
|
104
109
|
@abstractmethod
|
105
|
-
def score(self, X: ArrayLike, batch_size: int = int(1e10)) ->
|
110
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
106
111
|
"""
|
107
112
|
Compute the out-of-distribution (OOD) scores for a given dataset.
|
108
113
|
|
@@ -116,7 +121,7 @@ class OODBase(ABC):
|
|
116
121
|
|
117
122
|
Returns
|
118
123
|
-------
|
119
|
-
|
124
|
+
OODScoreOutput
|
120
125
|
An object containing the instance-level and feature-level OOD scores.
|
121
126
|
"""
|
122
127
|
|
@@ -197,7 +202,7 @@ class OODBase(ABC):
|
|
197
202
|
# compute outlier scores
|
198
203
|
score = self.score(X, batch_size=batch_size)
|
199
204
|
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
200
|
-
return OODOutput(is_ood=ood_pred, **score.
|
205
|
+
return OODOutput(is_ood=ood_pred, **score.dict())
|
201
206
|
|
202
207
|
|
203
208
|
class OODGMMBase(OODBase):
|
@@ -18,11 +18,12 @@ from keras.layers import Input
|
|
18
18
|
from keras.models import Model
|
19
19
|
from numpy.typing import ArrayLike, NDArray
|
20
20
|
|
21
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
21
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
22
22
|
from dataeval._internal.interop import to_numpy
|
23
23
|
from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
|
24
24
|
from dataeval._internal.models.tensorflow.trainer import trainer
|
25
25
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
26
|
+
from dataeval._internal.output import set_metadata
|
26
27
|
|
27
28
|
|
28
29
|
def build_model(
|
@@ -124,7 +125,7 @@ class OOD_LLR(OODBase):
|
|
124
125
|
self.sequential = sequential
|
125
126
|
self.log_prob = log_prob
|
126
127
|
|
127
|
-
self._ref_score:
|
128
|
+
self._ref_score: OODScoreOutput
|
128
129
|
self._threshold_perc: float
|
129
130
|
self._data_info: tuple[tuple, type] | None = None
|
130
131
|
|
@@ -279,12 +280,13 @@ class OOD_LLR(OODBase):
|
|
279
280
|
logp_b = logp_fn(self.dist_b, X, return_per_feature=return_per_feature, batch_size=batch_size)
|
280
281
|
return logp_s - logp_b
|
281
282
|
|
283
|
+
@set_metadata("dataeval.detectors")
|
282
284
|
def score(
|
283
285
|
self,
|
284
286
|
X: ArrayLike,
|
285
287
|
batch_size: int = int(1e10),
|
286
|
-
) ->
|
288
|
+
) -> OODScoreOutput:
|
287
289
|
self._validate(X := to_numpy(X))
|
288
290
|
fscore = -self._llr(X, True, batch_size=batch_size)
|
289
291
|
iscore = -self._llr(X, False, batch_size=batch_size)
|
290
|
-
return
|
292
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -15,11 +15,12 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
18
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
19
19
|
from dataeval._internal.interop import to_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import VAE
|
21
21
|
from dataeval._internal.models.tensorflow.losses import Elbo
|
22
22
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
23
|
+
from dataeval._internal.output import set_metadata
|
23
24
|
|
24
25
|
|
25
26
|
class OOD_VAE(OODBase):
|
@@ -67,7 +68,8 @@ class OOD_VAE(OODBase):
|
|
67
68
|
loss_fn = Elbo(0.05)
|
68
69
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
69
70
|
|
70
|
-
|
71
|
+
@set_metadata("dataeval.detectors")
|
72
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
71
73
|
self._validate(X := to_numpy(X))
|
72
74
|
|
73
75
|
# sample reconstructed instances
|
@@ -86,4 +88,4 @@ class OOD_VAE(OODBase):
|
|
86
88
|
sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
|
87
89
|
iscore = np.mean(sorted_fscore_perc, axis=1)
|
88
90
|
|
89
|
-
return
|
91
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -15,12 +15,13 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODGMMBase,
|
18
|
+
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
|
19
19
|
from dataeval._internal.interop import to_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
|
21
21
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
22
22
|
from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
|
23
23
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
24
|
+
from dataeval._internal.output import set_metadata
|
24
25
|
|
25
26
|
|
26
27
|
class OOD_VAEGMM(OODGMMBase):
|
@@ -53,7 +54,8 @@ class OOD_VAEGMM(OODGMMBase):
|
|
53
54
|
loss_fn = LossGMM(elbo=Elbo(0.05))
|
54
55
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
55
56
|
|
56
|
-
|
57
|
+
@set_metadata("dataeval.detectors")
|
58
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
57
59
|
"""
|
58
60
|
Compute the out-of-distribution (OOD) score for a given dataset.
|
59
61
|
|
@@ -67,7 +69,7 @@ class OOD_VAEGMM(OODGMMBase):
|
|
67
69
|
|
68
70
|
Returns
|
69
71
|
-------
|
70
|
-
|
72
|
+
OODScoreOutput
|
71
73
|
An object containing the instance-level OOD score.
|
72
74
|
|
73
75
|
Note
|
@@ -84,4 +86,4 @@ class OOD_VAEGMM(OODGMMBase):
|
|
84
86
|
energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
|
85
87
|
energy_samples = energy.numpy().reshape((-1, self.samples)) # type: ignore
|
86
88
|
iscore = np.mean(energy_samples, axis=-1)
|
87
|
-
return
|
89
|
+
return OODScoreOutput(iscore)
|