dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +8 -64
- dataeval/detectors/drift/_mmd.py +12 -38
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +6 -5
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -2
- dataeval/detectors/linters/duplicates.py +14 -46
- dataeval/detectors/linters/outliers.py +25 -159
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +6 -5
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +3 -4
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/metadata/__init__.py +2 -1
- dataeval/metadata/_distance.py +134 -0
- dataeval/metadata/_ood.py +30 -49
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/_balance.py +17 -149
- dataeval/metrics/bias/_coverage.py +4 -106
- dataeval/metrics/bias/_diversity.py +12 -107
- dataeval/metrics/bias/_parity.py +7 -71
- dataeval/metrics/estimators/__init__.py +5 -4
- dataeval/metrics/estimators/_ber.py +2 -20
- dataeval/metrics/estimators/_clusterer.py +1 -61
- dataeval/metrics/estimators/_divergence.py +2 -19
- dataeval/metrics/estimators/_uap.py +2 -16
- dataeval/metrics/stats/__init__.py +15 -12
- dataeval/metrics/stats/_base.py +41 -128
- dataeval/metrics/stats/_boxratiostats.py +13 -13
- dataeval/metrics/stats/_dimensionstats.py +17 -58
- dataeval/metrics/stats/_hashstats.py +19 -35
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +42 -121
- dataeval/metrics/stats/_pixelstats.py +19 -51
- dataeval/metrics/stats/_visualstats.py +19 -51
- dataeval/outputs/__init__.py +57 -0
- dataeval/outputs/_base.py +182 -0
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +186 -0
- dataeval/outputs/_metadata.py +54 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +393 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +187 -7
- dataeval/utils/_method.py +1 -5
- dataeval/utils/_plot.py +2 -2
- dataeval/utils/data/__init__.py +5 -1
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +12 -14
- dataeval/utils/data/_images.py +30 -27
- dataeval/utils/data/_metadata.py +28 -11
- dataeval/utils/data/_selection.py +25 -22
- dataeval/utils/data/_split.py +5 -29
- dataeval/utils/data/_targets.py +14 -2
- dataeval/utils/data/datasets/_base.py +5 -5
- dataeval/utils/data/datasets/_cifar10.py +1 -1
- dataeval/utils/data/datasets/_milco.py +1 -1
- dataeval/utils/data/datasets/_mnist.py +1 -1
- dataeval/utils/data/datasets/_ships.py +1 -1
- dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
- dataeval/utils/data/datasets/_voc.py +1 -1
- dataeval/utils/data/selections/_classfilter.py +4 -5
- dataeval/utils/data/selections/_indices.py +2 -2
- dataeval/utils/data/selections/_limit.py +2 -2
- dataeval/utils/data/selections/_reverse.py +2 -2
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +6 -342
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
- dataeval-0.82.1.dist-info/RECORD +105 -0
- dataeval/_output.py +0 -137
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/metrics/stats/_datasetstats.py +0 -198
- dataeval-0.81.0.dist-info/RECORD +0 -94
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
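Taken together, the file moves above fold the scattered output classes into a new dataeval.outputs package and replace the 0.81 datasetstats aggregate with imagestats. A minimal sketch of the import migration implied by the diffs below, assuming dataeval 0.82.1 is installed; old paths are shown as comments for contrast:

# 0.81.0 paths (removed in this release):
#   from dataeval._output import Output, set_metadata
#   from dataeval.metrics.stats._datasetstats import datasetstats
#   from dataeval.detectors.ood.output import OODOutput, OODScoreOutput

# 0.82.1 paths, as introduced by the diffs below:
from dataeval.metrics.stats._imagestats import imagestats  # replaces datasetstats
from dataeval.outputs import OODOutput, OODScoreOutput, OutliersOutput
from dataeval.outputs._base import set_metadata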
dataeval/detectors/linters/outliers.py
CHANGED
@@ -2,142 +2,19 @@ from __future__ import annotations
 
 __all__ = []
 
-import contextlib
-from dataclasses import dataclass
-from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Literal, Sequence, overload
 
 import numpy as np
 from numpy.typing import NDArray
-from torch.utils.data import Dataset
-
-from dataeval._output import Output, set_metadata
-from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats._datasetstats import DatasetStatsOutput, datasetstats
-from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
-from dataeval.metrics.stats._labelstats import LabelStatsOutput
-from dataeval.metrics.stats._pixelstats import PixelStatsOutput
-from dataeval.metrics.stats._visualstats import VisualStatsOutput
-from dataeval.typing import ArrayLike
-
-with contextlib.suppress(ImportError):
-    import pandas as pd
-
-
-IndexIssueMap = dict[int, dict[str, float]]
-OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
-TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
-
-
-def _reorganize_by_class_and_metric(result, lstats):
-    """Flip result from grouping by image to grouping by class and metric"""
-    metrics = {}
-    class_wise = {label: {} for label in lstats.image_indices_per_label}
-
-    # Group metrics and calculate class-wise counts
-    for img, group in result.items():
-        for extreme in group:
-            metrics.setdefault(extreme, []).append(img)
-            for label, images in lstats.image_indices_per_label.items():
-                if img in images:
-                    class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
-
-    return metrics, class_wise
-
-
-def _create_table(metrics, class_wise):
-    """Create table for displaying the results"""
-    max_class_length = max(len(str(label)) for label in class_wise) + 2
-    max_total = max(len(metrics[group]) for group in metrics) + 2
-
-    table_header = " | ".join(
-        [f"{'Class':>{max_class_length}}"]
-        + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
-        + [f"{'Total':<{max_total}}"]
-    )
-    table_rows = []
-
-    for class_cat, results in class_wise.items():
-        table_value = [f"{class_cat:>{max_class_length}}"]
-        total = 0
-        for group in sorted(metrics.keys()):
-            count = results.get(group, 0)
-            table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
-            total += count
-        table_value.append(f"{total:^{max_total}}")
-        table_rows.append(" | ".join(table_value))
-
-    table = [table_header] + table_rows
-    return table
-
-
-def _create_pandas_dataframe(class_wise):
-    """Create data for pandas dataframe"""
-    data = []
-    for label, metrics_dict in class_wise.items():
-        row = {"Class": label}
-        total = sum(metrics_dict.values())
-        row.update(metrics_dict)  # Add metric counts
-        row["Total"] = total
-        data.append(row)
-    return data
-
-
-@dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
-    """
-    Output class for :class:`.Outliers` lint detector.
-
-    Attributes
-    ----------
-    issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
-        Indices of image Outliers with their associated issue type and calculated values.
-
-        - For a single dataset, a dictionary containing the indices of outliers and
-          a dictionary showing the issues and calculated values for the given index.
-        - For multiple stats outputs, a list of dictionaries containing the indices of
-          outliers and their associated issues and calculated values.
-    """
 
-
-
-
-
-
-
-
-
-    def to_table(self, labelstats: LabelStatsOutput) -> str:
-        if isinstance(self.issues, dict):
-            metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            listed_table = _create_table(metrics, classwise)
-            table = "\n".join(listed_table)
-        else:
-            outertable = []
-            for d in self.issues:
-                metrics, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                listed_table = _create_table(metrics, classwise)
-                str_table = "\n".join(listed_table)
-                outertable.append(str_table)
-            table = "\n\n".join(outertable)
-        return table
-
-    def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
-        import pandas as pd
-
-        if isinstance(self.issues, dict):
-            _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            data = _create_pandas_dataframe(classwise)
-            df = pd.DataFrame(data)
-        else:
-            df_list = []
-            for i, d in enumerate(self.issues):
-                _, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                data = _create_pandas_dataframe(classwise)
-                single_df = pd.DataFrame(data)
-                single_df["Dataset"] = i
-                df_list.append(single_df)
-            df = pd.concat(df_list)
-        return df
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._imagestats import imagestats
+from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
+from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images
 
 
 def _get_outlier_mask(
@@ -227,7 +104,7 @@ class Outliers:
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
     ):
-        self.stats: DatasetStatsOutput
+        self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
         self.use_visual = use_visual
@@ -248,23 +125,23 @@ class Outliers:
         return dict(sorted(flagged_images.items()))
 
     @overload
-    def from_stats(self, stats: OutlierStatsOutput | DatasetStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
+    def from_stats(self, stats: OutlierStatsOutput | ImageStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
 
     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
 
     @set_metadata(state=["outlier_method", "outlier_threshold"])
     def from_stats(
-        self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        self, stats: OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
    ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
         """
         Returns indices of Outliers with the issues identified for each.
 
         Parameters
         ----------
-        stats : OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        stats : OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
             The output(s) from a dimensionstats, pixelstats, or visualstats metric
-            analysis or an aggregate DatasetStatsOutput
+            analysis or an aggregate ImageStatsOutput
 
         Returns
         -------
@@ -291,12 +168,8 @@ class Outliers:
         >>> results.issues[1]
         {}
         """  # noqa: E501
-        if isinstance(stats, DatasetStatsOutput):
-
-            return OutliersOutput(outliers)
-
-        if isinstance(stats, (DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
-            return OutliersOutput(self._get_outliers(stats.dict()))
+        if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+            return OutliersOutput(self._get_outliers(stats.data()))
 
         if not isinstance(stats, Sequence):
             raise TypeError(
@@ -306,7 +179,7 @@ class Outliers:
         stats_map: dict[type, list[int]] = {}
         for i, stats_output in enumerate(stats):
             if not isinstance(
-                stats_output, (DatasetStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
             ):
                 raise TypeError(
                     "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -316,29 +189,22 @@ class Outliers:
         output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
         for _, indices in stats_map.items():
             substats, dataset_steps = combine_stats([stats[i] for i in indices])
-            outliers = self._get_outliers(substats.dict())
+            outliers = self._get_outliers(substats.data())
             for idx, issue in outliers.items():
                 k, v = get_dataset_step_from_idx(idx, dataset_steps)
                 output_list[indices[k]][v] = issue
 
         return OutliersOutput(output_list)
 
-    @overload
-    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]: ...
-    @overload
-    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> OutliersOutput[IndexIssueMap]: ...
-
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(
-        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
-    ) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (C, H, W)
-            A dataset of images in an ArrayLike format
+        data : Iterable[Array], shape - (C, H, W)
+            A dataset of images in an Array format
 
         Returns
        -------
@@ -355,9 +221,9 @@ class Outliers:
         >>> list(results.issues)
         [10, 12]
         >>> results.issues[10]
-        {'
+        {'contrast': 1.25, 'zeros': 0.05493, 'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}
         """
-        images = (
-        self.stats =
-        outliers = self._get_outliers(self.stats.
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = imagestats(images)
+        outliers = self._get_outliers(self.stats.data())
         return OutliersOutput(outliers)
dataeval/detectors/ood/__init__.py
CHANGED
@@ -5,4 +5,4 @@ Out-of-distribution (OOD) detectors identify data that is different from the data
 __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
 
 from dataeval.detectors.ood.ae import OOD_AE
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
+from dataeval.outputs._ood import OODOutput, OODScoreOutput
dataeval/detectors/ood/ae.py
CHANGED
@@ -18,8 +18,9 @@ import numpy as np
 import torch
 from numpy.typing import NDArray
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
 from dataeval.utils.torch._internal import predict_batch
 
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
     model : torch.nn.Module
         An autoencoder model to use for encoding and reconstruction of images
         for detection of out-of-distribution samples.
-    device :
-        The device to use
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
 
     Example
     -------
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
     array([ True,  True, False,  True,  True,  True,  True,  True])
     """
 
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(
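The ae.py change (mirrored in base.py and vae.py below) threads the new DeviceLike alias from dataeval.config through every OOD detector constructor. A construction-only sketch assuming dataeval 0.82.1 and torch; the toy autoencoder and its layer sizes are illustrative only:

import torch

from dataeval.detectors.ood import OOD_AE

# Toy autoencoder whose output shape matches its (3, 16, 16) input,
# so reconstruction error is well defined; sizes are illustrative only.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 16 * 16, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3 * 16 * 16),
    torch.nn.Unflatten(1, (3, 16, 16)),
)

# device accepts a DeviceLike value ("cpu", a torch.device, ...); None defers
# to the DataEval default (see dataeval.config) or the torch default.
detector = OOD_AE(model, device="cpu")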
dataeval/detectors/ood/base.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Callable, cast
 
 import torch
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import to_numpy
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         self.device: torch.device = get_device(device)
         super().__init__(model)
 
dataeval/detectors/ood/metadata_ood_mi.py
CHANGED
@@ -10,6 +10,8 @@ import numpy as np
 from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif
 
+from dataeval.config import get_seed
+
 # NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
 # which is what many library functions return, multiply it by NATS2BITS to get it in bits.
 NATS2BITS = 1.442695
@@ -19,7 +21,6 @@ def get_metadata_ood_mi(
     metadata: dict[str, list[Any] | NDArray[Any]],
     is_ood: NDArray[np.bool_],
     discrete_features: str | bool | NDArray[np.bool_] = False,
-    random_state: int | None = None,
 ) -> dict[str, float]:
     """Computes mutual information between a set of metadata features and an out-of-distribution flag.
 
@@ -39,9 +40,6 @@ get_metadata_ood_mi(
         A boolean array, with one value per example, that indicates which examples are OOD.
     discrete_features : str | bool | NDArray[np.bool_]
         Either a boolean array or a single boolean value, indicate which features take on discrete values.
-    random_state : int, optional - default None
-        Determines random number generation for small noise added to continuous variables. Set to a value for
-        reproducible results.
 
     Returns
     -------
@@ -55,7 +53,7 @@ get_metadata_ood_mi(
 
     >>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
     >>> is_ood = metadata["altitude"] > 100
-    >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False
+    >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False)
     {'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
     """
     numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
@@ -84,7 +82,7 @@ get_metadata_ood_mi(
             Xscl,
             is_ood,
             discrete_features=discrete_features,  # type: ignore
-            random_state=random_state,
+            random_state=get_seed(),
         )
         * NATS2BITS
     )
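With the random_state parameter gone, reproducibility for get_metadata_ood_mi is now controlled globally through dataeval.config (the call site above reads get_seed()). A sketch reproducing the updated doctest, assuming dataeval 0.82.1:

import numpy as np

from dataeval.detectors.ood.metadata_ood_mi import get_metadata_ood_mi

metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
is_ood = metadata["altitude"] > 100  # flag the quadratically growing tail as OOD

# discrete_features=False treats both factors as continuous; the returned
# mutual information is reported in bits (nats * NATS2BITS)
print(get_metadata_ood_mi(metadata, is_ood, discrete_features=False))
# {'time': 0.9359..., 'altitude': 0.9407...}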
dataeval/detectors/ood/mixin.py
CHANGED
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
-
 __all__ = []
 
 from abc import ABC, abstractmethod
@@ -10,7 +8,8 @@ from typing import Callable, Generic, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
-from dataeval._output import set_metadata
+from dataeval.outputs import OODOutput, OODScoreOutput
+from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import as_numpy, to_numpy
 
@@ -158,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
-        return OODOutput(is_ood=ood_pred, **score.dict())
+        return OODOutput(is_ood=ood_pred, **score.data())
dataeval/detectors/ood/vae.py
CHANGED
@@ -17,8 +17,9 @@ from typing import Callable
 import numpy as np
 import torch
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import as_numpy
 from dataeval.utils.torch._internal import predict_batch
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
         An Autoencoder model.
     """
 
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(
dataeval/metadata/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""
 
-__all__ = ["most_deviated_factors"]
+__all__ = ["most_deviated_factors", "metadata_distance"]
 
+from dataeval.metadata._distance import metadata_distance
 from dataeval.metadata._ood import most_deviated_factors
dataeval/metadata/_distance.py
ADDED
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from typing import NamedTuple, cast
+
+import numpy as np
+from scipy.stats import iqr, ks_2samp
+from scipy.stats import wasserstein_distance as emd
+
+from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils.data import Metadata
+
+
+class KSType(NamedTuple):
+    """Used to typehint scipy's internal hidden ks_2samp output"""
+
+    statistic: float
+    statistic_location: float
+    pvalue: float
+
+
+def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
+    """Calculates the shift magnitude between x1 and x2 scaled by x1"""
+
+    distance = emd(x1, x2)
+
+    X = iqr(x1)
+
+    # Preferred scaling of x1
+    if X:
+        return distance / X
+
+    # Return if single-valued, else scale
+    xmin, xmax = np.min(x1), np.max(x1)
+    return distance if xmin == xmax else distance / (xmax - xmin)
+
+
+@set_metadata
+def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
+    """
+    Measures the feature-wise distance between two continuous metadata distributions and
+    computes a p-value to evaluate its significance.
+
+    Uses the Earth Mover's Distance and the Kolmogorov-Smirnov two-sample test, featurewise.
+
+    Parameters
+    ----------
+    metadata1 : Metadata
+        Class containing continuous factor names and values to be used as reference
+    metadata2 : Metadata
+        Class containing continuous factor names and values to be compare with the reference
+
+    Returns
+    -------
+    MetadataDistanceOutput
+        A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
+        defined by scipy.stats.ks_2samp.
+
+    See Also
+    --------
+    Earth mover's distance
+
+    Kolmogorov-Smirnov two-sample test
+
+    Note
+    ----
+    This function only applies to the continuous data
+
+    Examples
+    --------
+    >>> output = metadata_distance(metadata1, metadata2)
+    >>> list(output)
+    ['time', 'altitude']
+    >>> output["time"]
+    MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
+    """
+
+    _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
+    fnames = metadata1.continuous_factor_names
+
+    cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
+    cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+
+    _validate_factors_and_data(fnames, cont1)
+    _validate_factors_and_data(fnames, cont2)
+
+    N = len(cont1)
+    M = len(cont2)
+
+    # This is a simplified version of sqrt(N*M / N+M) < 4
+    if (N - 16) * (M - 16) < 256:
+        warnings.warn(
+            f"Sample sizes of {N}, {M} will yield unreliable p-values from the KS test. "
+            f"Recommended 32 samples per factor or at least 16 if one set has many more.",
+            UserWarning,
+        )
+
+    # Set default for statistic, location, and magnitude to zero and pvalue to one
+    results: dict[str, MetadataDistanceValues] = {}
+
+    # Per factor
+    for i, fname in enumerate(fnames):
+        fdata1 = cont1[:, i]  # (S, 1)
+        fdata2 = cont2[:, i]  # (S, 1)
+
+        # Min and max over both distributions
+        xmin = min(np.min(fdata1), np.min(fdata2))
+        xmax = max(np.max(fdata1), np.max(fdata2))
+
+        # Default case
+        if xmin == xmax:
+            results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
+            continue
+
+        ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
+
+        # Normalized location
+        loc = float((ks_result.statistic_location - xmin) / (xmax - xmin))
+
+        drift = _calculate_drift(fdata1, fdata2)
+
+        results[fname] = MetadataDistanceValues(
+            statistic=ks_result.statistic,
+            location=loc,
+            dist=drift,
+            pvalue=ks_result.pvalue,
+        )
+
+    return MetadataDistanceOutput(results)