dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/detectors/linters/__init__.py CHANGED
@@ -3,14 +3,12 @@ Linters help identify potential issues in training and test data and are an impo
 """

 __all__ = [
-    "Clusterer",
-    "ClustererOutput",
     "Duplicates",
     "DuplicatesOutput",
     "Outliers",
     "OutliersOutput",
 ]

-from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
-from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
-from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
+from dataeval.detectors.linters.duplicates import Duplicates
+from dataeval.detectors.linters.outliers import Outliers
+from dataeval.outputs._linters import DuplicatesOutput, OutliersOutput
dataeval/detectors/linters/duplicates.py CHANGED
@@ -2,39 +2,15 @@ from __future__ import annotations

 __all__ = []

-from dataclasses import dataclass
-from typing import Generic, Iterable, Sequence, TypeVar, overload
+from typing import Any, Sequence, overload

-from numpy.typing import ArrayLike
-
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
-from dataeval.output import Output, set_metadata
-
-DuplicateGroup = list[int]
-DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
-TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
-
-
-@dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], Output):
-    """
-    Output class for :class:`Duplicates` lint detector.
-
-    Attributes
-    ----------
-    exact : list[list[int] | dict[int, list[int]]]
-        Indices of images that are exact matches
-    near: list[list[int] | dict[int, list[int]]]
-        Indices of images that are near matches
-
-    - For a single dataset, indices are returned as a list of index groups.
-    - For multiple datasets, indices are returned as dictionaries where the key is the
-      index of the dataset, and the value is the list index groups from that dataset.
-    """
-
-    exact: list[TIndexCollection]
-    near: list[TIndexCollection]
+from dataeval.metrics.stats import hashstats
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.outputs import DuplicatesOutput, HashStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import DatasetDuplicateGroupMap, DuplicateGroup
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images


 class Duplicates:
@@ -134,14 +110,14 @@ class Duplicates:
         return DuplicatesOutput(**duplicates)

     @set_metadata(state=["only_exact"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches

         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (N, C, H, W)
-            A dataset of images in an ArrayLike format
+        data : Iterable[Array], shape - (N, C, H, W) | Dataset[tuple[Array, Any, Any]]
+            A dataset of images in an Array format or the output(s) from a hashstats analysis

         Returns
         -------
@@ -158,6 +134,7 @@ class Duplicates:
         >>> all_dupes.evaluate(duplicate_images)
         DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
         """  # noqa: E501
-        self.stats = hashstats(data)
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = hashstats(images)
         duplicates = self._get_duplicates(self.stats.dict())
         return DuplicatesOutput(**duplicates)
dataeval/detectors/linters/outliers.py CHANGED
@@ -2,141 +2,19 @@ from __future__ import annotations

 __all__ = []

-import contextlib
-from dataclasses import dataclass
-from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Literal, Sequence, overload

 import numpy as np
-from numpy.typing import ArrayLike, NDArray
-
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
-from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
-from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
-from dataeval.metrics.stats.labelstats import LabelStatsOutput
-from dataeval.metrics.stats.pixelstats import PixelStatsOutput
-from dataeval.metrics.stats.visualstats import VisualStatsOutput
-from dataeval.output import Output, set_metadata
-
-with contextlib.suppress(ImportError):
-    import pandas as pd
-
-
-IndexIssueMap = dict[int, dict[str, float]]
-OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
-TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
-
-
-def _reorganize_by_class_and_metric(result, lstats):
-    """Flip result from grouping by image to grouping by class and metric"""
-    metrics = {}
-    class_wise = {label: {} for label in lstats.image_indices_per_label}
-
-    # Group metrics and calculate class-wise counts
-    for img, group in result.items():
-        for extreme in group:
-            metrics.setdefault(extreme, []).append(img)
-            for label, images in lstats.image_indices_per_label.items():
-                if img in images:
-                    class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
-
-    return metrics, class_wise
-
-
-def _create_table(metrics, class_wise):
-    """Create table for displaying the results"""
-    max_class_length = max(len(str(label)) for label in class_wise) + 2
-    max_total = max(len(metrics[group]) for group in metrics) + 2
-
-    table_header = " | ".join(
-        [f"{'Class':>{max_class_length}}"]
-        + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
-        + [f"{'Total':<{max_total}}"]
-    )
-    table_rows = []
-
-    for class_cat, results in class_wise.items():
-        table_value = [f"{class_cat:>{max_class_length}}"]
-        total = 0
-        for group in sorted(metrics.keys()):
-            count = results.get(group, 0)
-            table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
-            total += count
-        table_value.append(f"{total:^{max_total}}")
-        table_rows.append(" | ".join(table_value))
-
-    table = [table_header] + table_rows
-    return table
-
-
-def _create_pandas_dataframe(class_wise):
-    """Create data for pandas dataframe"""
-    data = []
-    for label, metrics_dict in class_wise.items():
-        row = {"Class": label}
-        total = sum(metrics_dict.values())
-        row.update(metrics_dict)  # Add metric counts
-        row["Total"] = total
-        data.append(row)
-    return data
-
-
-@dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
-    """
-    Output class for :class:`Outliers` lint detector.
+from numpy.typing import NDArray

-    Attributes
-    ----------
-    issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
-        Indices of image outliers with their associated issue type and calculated values.
-
-        - For a single dataset, a dictionary containing the indices of outliers
-          and the issues and calculated values for the given index.
-        - For multiple stats outputs, a list of dictionaries containing the indices of
-        outliers and their associated issues and calculated values.
-    """
-
-    issues: TIndexIssueMap
-
-    def __len__(self) -> int:
-        if isinstance(self.issues, dict):
-            return len(self.issues)
-        else:
-            return sum(len(d) for d in self.issues)
-
-    def to_table(self, labelstats: LabelStatsOutput) -> str:
-        if isinstance(self.issues, dict):
-            metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            listed_table = _create_table(metrics, classwise)
-            table = "\n".join(listed_table)
-        else:
-            outertable = []
-            for d in self.issues:
-                metrics, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                listed_table = _create_table(metrics, classwise)
-                str_table = "\n".join(listed_table)
-                outertable.append(str_table)
-            table = "\n\n".join(outertable)
-        return table
-
-    def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
-        import pandas as pd
-
-        if isinstance(self.issues, dict):
-            _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            data = _create_pandas_dataframe(classwise)
-            df = pd.DataFrame(data)
-        else:
-            df_list = []
-            for i, d in enumerate(self.issues):
-                _, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                data = _create_pandas_dataframe(classwise)
-                single_df = pd.DataFrame(data)
-                single_df["Dataset"] = i
-                df_list.append(single_df)
-            df = pd.concat(df_list)
-        return df
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._imagestats import imagestats
+from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
+from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images


 def _get_outlier_mask(
@@ -226,7 +104,7 @@ class Outliers:
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
     ):
-        self.stats: DatasetStatsOutput
+        self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
         self.use_visual = use_visual
@@ -247,23 +125,23 @@ class Outliers:
         return dict(sorted(flagged_images.items()))

     @overload
-    def from_stats(self, stats: OutlierStatsOutput | DatasetStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
+    def from_stats(self, stats: OutlierStatsOutput | ImageStatsOutput) -> OutliersOutput[IndexIssueMap]: ...

     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...

     @set_metadata(state=["outlier_method", "outlier_threshold"])
     def from_stats(
-        self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        self, stats: OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
     ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
         """
         Returns indices of Outliers with the issues identified for each.

         Parameters
         ----------
-        stats : OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        stats : OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
             The output(s) from a dimensionstats, pixelstats, or visualstats metric
-            analysis or an aggregate DatasetStatsOutput
+            analysis or an aggregate ImageStatsOutput

         Returns
         -------
@@ -290,11 +168,7 @@ class Outliers:
         >>> results.issues[1]
         {}
         """  # noqa: E501
-        if isinstance(stats, DatasetStatsOutput):
-            outliers = self._get_outliers({k: v for o in stats._outputs() for k, v in o.dict().items()})
-            return OutliersOutput(outliers)
-
-        if isinstance(stats, (DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+        if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
             return OutliersOutput(self._get_outliers(stats.dict()))

         if not isinstance(stats, Sequence):
@@ -305,7 +179,7 @@ class Outliers:
         stats_map: dict[type, list[int]] = {}
         for i, stats_output in enumerate(stats):
             if not isinstance(
-                stats_output, (DatasetStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
             ):
                 raise TypeError(
                     "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -323,14 +197,14 @@ class Outliers:
         return OutliersOutput(output_list)

     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each

         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (C, H, W)
-            A dataset of images in an ArrayLike format
+        data : Iterable[Array], shape - (C, H, W)
+            A dataset of images in an Array format

         Returns
         -------
@@ -347,8 +221,9 @@ class Outliers:
         >>> list(results.issues)
         [10, 12]
         >>> results.issues[10]
-        {'
+        {'contrast': 1.25, 'zeros': 0.05493, 'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}
         """
-        self.stats = datasetstats(images=data)
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = imagestats(images)
         outliers = self._get_outliers(self.stats.dict())
         return OutliersOutput(outliers)
dataeval/detectors/ood/__init__.py CHANGED
@@ -5,4 +5,4 @@ Out-of-distribution (OOD) detectors identify data that is different from the dat
 __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]

 from dataeval.detectors.ood.ae import OOD_AE
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
+from dataeval.outputs._ood import OODOutput, OODScoreOutput
dataeval/detectors/ood/ae.py CHANGED
@@ -16,12 +16,12 @@ from typing import Callable

 import numpy as np
 import torch
-from numpy.typing import ArrayLike
+from numpy.typing import NDArray

 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
-from dataeval.interop import as_numpy
-from dataeval.utils.torch.internal import predict_batch
+from dataeval.outputs import OODScoreOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils.torch._internal import predict_batch


 class OOD_AE(OODBase):
@@ -30,8 +30,31 @@ class OOD_AE(OODBase):

     Parameters
     ----------
-    model : Autoencoder
-        An Autoencoder model.
+    model : torch.nn.Module
+        An autoencoder model to use for encoding and reconstruction of images
+        for detection of out-of-distribution samples.
+    device : str or torch.Device or None, default None
+        The device to use for the detector. None will default to the global
+        configuration selection if set, otherwise "cuda" then "cpu" by availability.
+
+    Example
+    -------
+    Perform out-of-distribution detection on test data.
+
+    >>> from dataeval.utils.torch.models import AE
+
+    >>> input_shape = train_images[0].shape
+    >>> ood = OOD_AE(AE(input_shape))
+
+    Train the autoencoder using the training data.
+
+    >>> ood.fit(train_images, threshold_perc=99, epochs=20)
+
+    Test for out-of-distribution samples on the test data.
+
+    >>> output = ood.predict(test_images)
+    >>> output.is_ood
+    array([ True,  True, False,  True,  True,  True,  True,  True])
     """

     def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
@@ -55,9 +78,7 @@ class OOD_AE(OODBase):

         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)

-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
-        self._validate(X := as_numpy(X))
-
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
         # reconstruct instances
         X_recon = predict_batch(X, self.model, batch_size=batch_size)

dataeval/detectors/ood/base.py CHANGED
@@ -13,12 +13,13 @@ __all__ = []
 from typing import Callable, cast

 import torch
-from numpy.typing import ArrayLike

+from dataeval.config import get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
-from dataeval.interop import to_numpy
-from dataeval.utils.torch.gmm import GaussianMixtureModelParams, gmm_params
-from dataeval.utils.torch.internal import get_device, trainer
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import to_numpy
+from dataeval.utils.torch._gmm import GaussianMixtureModelParams, gmm_params
+from dataeval.utils.torch._internal import trainer


 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
dataeval/detectors/ood/mixin.py CHANGED
@@ -1,17 +1,17 @@
 from __future__ import annotations

-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
-
 __all__ = []

 from abc import ABC, abstractmethod
 from typing import Callable, Generic, Literal, TypeVar

 import numpy as np
-from numpy.typing import ArrayLike, NDArray
+from numpy.typing import NDArray

-from dataeval.interop import as_numpy, to_numpy
-from dataeval.output import set_metadata
+from dataeval.outputs import OODOutput, OODScoreOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy, to_numpy

 TGMMParams = TypeVar("TGMMParams")

@@ -73,6 +73,9 @@ class OODBaseMixin(Generic[TModel], ABC):
     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
             raise TypeError("Dataset should of type: `NDArray`.")
+        if np.min(X) < 0 or np.max(X) > 1:
+            raise ValueError("Embeddings must be on the unit interval [0-1].")
+
         return X.shape[1:], X.dtype.type

     def _validate(self, X: NDArray) -> None:
@@ -90,7 +93,7 @@ class OODBaseMixin(Generic[TModel], ABC):
         self._validate(X)

     @abstractmethod
-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput: ...

     @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
@@ -105,11 +108,17 @@ class OODBaseMixin(Generic[TModel], ABC):
             Number of instances to process in each batch.
             Use a smaller batch size if your dataset is large or if you encounter memory issues.

+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         OODScoreOutput
             An object containing the instance-level and feature-level OOD scores.
         """
+        self._validate(X := as_numpy(X).astype(np.float32))
         return self._score(X, batch_size)

     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
@@ -134,12 +143,17 @@ class OODBaseMixin(Generic[TModel], ABC):
         ood_type : "feature" | "instance", default "instance"
             Predict out-of-distribution at the 'feature' or 'instance' level.

+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         Dictionary containing the outlier predictions for the selected level,
         and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
         """
-        self._validate_state(X := to_numpy(X))
+        self._validate_state(X := to_numpy(X).astype(np.float32))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
dataeval/detectors/ood/vae.py ADDED
@@ -0,0 +1,73 @@
+"""
+Adapted for Pytorch from
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Callable
+
+import numpy as np
+import torch
+
+from dataeval.detectors.ood.base import OODBase
+from dataeval.outputs import OODScoreOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch._internal import predict_batch
+
+
+class OOD_VAE(OODBase):
+    """
+    Autoencoder based out-of-distribution detector.
+
+    Parameters
+    ----------
+    model : Autoencoder
+        An Autoencoder model.
+    """
+
+    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+        super().__init__(model, device)
+
+    def fit(
+        self,
+        x_ref: ArrayLike,
+        threshold_perc: float,
+        loss_fn: Callable[..., torch.nn.Module] | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        epochs: int = 20,
+        batch_size: int = 64,
+        verbose: bool = False,
+    ) -> None:
+        if loss_fn is None:
+            loss_fn = torch.nn.MSELoss()
+
+        if optimizer is None:
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
+
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+
+    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
+        self._validate(X := as_numpy(X))
+
+        # reconstruct instances
+        X_recon = predict_batch(X, self.model, batch_size=batch_size)[0]  # don't need mu or logvar from model
+
+        # compute feature and instance level scores
+        fscore = np.power(X.reshape((len(X), -1)) - X_recon, 2)
+        # fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
+        # n_score_features = int(np.ceil(fscore_flat.shape[1]))
+        # sorted_fscore = np.sort(fscore_flat, axis=1)
+        # sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
+        # iscore = np.mean(sorted_fscore_perc, axis=1)
+        iscore = np.sum(fscore, axis=1)
+
+        return OODScoreOutput(iscore, fscore)