dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/detectors/linters/__init__.py

@@ -3,14 +3,12 @@ Linters help identify potential issues in training and test data and are an impo
 """
 
 __all__ = [
-    "Clusterer",
-    "ClustererOutput",
     "Duplicates",
     "DuplicatesOutput",
     "Outliers",
    "OutliersOutput",
 ]
 
-from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
-from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
-from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
+from dataeval.detectors.linters.duplicates import Duplicates
+from dataeval.detectors.linters.outliers import Outliers
+from dataeval.outputs._linters import DuplicatesOutput, OutliersOutput
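The net effect of this hunk: the Clusterer lint detector is gone, and the output dataclasses now live in the consolidated dataeval.outputs package while remaining re-exported from the linters namespace. A minimal sketch of the import migration for downstream code, assuming no renames beyond what this hunk shows:

    # dataeval 0.82.0: detectors stay put; output classes are also
    # importable from the new consolidated outputs package
    from dataeval.detectors.linters import Duplicates, Outliers
    from dataeval.outputs import DuplicatesOutput, OutliersOutput

    # dataeval 0.76.1 equivalent (Clusterer/ClustererOutput no longer exist):
    # from dataeval.detectors.linters import Clusterer, Duplicates, Outliers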
dataeval/detectors/linters/duplicates.py

@@ -2,39 +2,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from dataclasses import dataclass
-from typing import Generic, Iterable, Sequence, TypeVar, overload
+from typing import Any, Sequence, overload
 
-from numpy.typing import ArrayLike
-
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
-from dataeval.output import Output, set_metadata
-
-DuplicateGroup = list[int]
-DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
-TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
-
-
-@dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], Output):
-    """
-    Output class for :class:`Duplicates` lint detector.
-
-    Attributes
-    ----------
-    exact : list[list[int] | dict[int, list[int]]]
-        Indices of images that are exact matches
-    near: list[list[int] | dict[int, list[int]]]
-        Indices of images that are near matches
-
-    - For a single dataset, indices are returned as a list of index groups.
-    - For multiple datasets, indices are returned as dictionaries where the key is the
-      index of the dataset, and the value is the list index groups from that dataset.
-    """
-
-    exact: list[TIndexCollection]
-    near: list[TIndexCollection]
+from dataeval.metrics.stats import hashstats
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.outputs import DuplicatesOutput, HashStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import DatasetDuplicateGroupMap, DuplicateGroup
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images
 
 
 class Duplicates:
@@ -134,14 +110,14 @@ class Duplicates:
         return DuplicatesOutput(**duplicates)
 
     @set_metadata(state=["only_exact"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches
 
         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (N, C, H, W) | StatsOutput | Sequence[StatsOutput]
-            A dataset of images in an ArrayLike format or the output(s) from a hashstats analysis
+        data : Iterable[Array], shape - (N, C, H, W) | Dataset[tuple[Array, Any, Any]]
+            A dataset of images in an Array format or the output(s) from a hashstats analysis
 
         Returns
         -------
@@ -158,6 +134,7 @@ class Duplicates:
         >>> all_dupes.evaluate(duplicate_images)
         DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
         """  # noqa: E501
-        self.stats = hashstats(data)
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = hashstats(images)
         duplicates = self._get_duplicates(self.stats.dict())
         return DuplicatesOutput(**duplicates)
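Taken together, the duplicates.py hunks change Duplicates.evaluate from accepting a bare iterable of arrays to accepting a Dataset, which is wrapped in the internal Images view before hashing. A hedged usage sketch; the dataset variable is illustrative, and per the new signature any Dataset[Array] or Dataset[tuple[Array, Any, Any]] should work:

    from dataeval.detectors.linters import Duplicates

    dupes = Duplicates()
    result = dupes.evaluate(dataset)  # Dataset[Array] | Dataset[tuple[Array, Any, Any]]
    print(result.exact)               # e.g. [[3, 20], [16, 37]]
    print(result.near)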
dataeval/detectors/linters/outliers.py

@@ -2,141 +2,19 @@ from __future__ import annotations
 
 __all__ = []
 
-import contextlib
-from dataclasses import dataclass
-from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Literal, Sequence, overload
 
 import numpy as np
-from numpy.typing import ArrayLike, NDArray
-
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
-from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
-from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
-from dataeval.metrics.stats.labelstats import LabelStatsOutput
-from dataeval.metrics.stats.pixelstats import PixelStatsOutput
-from dataeval.metrics.stats.visualstats import VisualStatsOutput
-from dataeval.output import Output, set_metadata
-
-with contextlib.suppress(ImportError):
-    import pandas as pd
-
-
-IndexIssueMap = dict[int, dict[str, float]]
-OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
-TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
-
-
-def _reorganize_by_class_and_metric(result, lstats):
-    """Flip result from grouping by image to grouping by class and metric"""
-    metrics = {}
-    class_wise = {label: {} for label in lstats.image_indices_per_label}
-
-    # Group metrics and calculate class-wise counts
-    for img, group in result.items():
-        for extreme in group:
-            metrics.setdefault(extreme, []).append(img)
-            for label, images in lstats.image_indices_per_label.items():
-                if img in images:
-                    class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
-
-    return metrics, class_wise
-
-
-def _create_table(metrics, class_wise):
-    """Create table for displaying the results"""
-    max_class_length = max(len(str(label)) for label in class_wise) + 2
-    max_total = max(len(metrics[group]) for group in metrics) + 2
-
-    table_header = " | ".join(
-        [f"{'Class':>{max_class_length}}"]
-        + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
-        + [f"{'Total':<{max_total}}"]
-    )
-    table_rows = []
-
-    for class_cat, results in class_wise.items():
-        table_value = [f"{class_cat:>{max_class_length}}"]
-        total = 0
-        for group in sorted(metrics.keys()):
-            count = results.get(group, 0)
-            table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
-            total += count
-        table_value.append(f"{total:^{max_total}}")
-        table_rows.append(" | ".join(table_value))
-
-    table = [table_header] + table_rows
-    return table
-
-
-def _create_pandas_dataframe(class_wise):
-    """Create data for pandas dataframe"""
-    data = []
-    for label, metrics_dict in class_wise.items():
-        row = {"Class": label}
-        total = sum(metrics_dict.values())
-        row.update(metrics_dict)  # Add metric counts
-        row["Total"] = total
-        data.append(row)
-    return data
-
-
-@dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
-    """
-    Output class for :class:`Outliers` lint detector.
+from numpy.typing import NDArray
 
-    Attributes
-    ----------
-    issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
-        Indices of image Outliers with their associated issue type and calculated values.
-
-        - For a single dataset, a dictionary containing the indices of outliers and
-          a dictionary showing the issues and calculated values for the given index.
-        - For multiple stats outputs, a list of dictionaries containing the indices of
-          outliers and their associated issues and calculated values.
-    """
-
-    issues: TIndexIssueMap
-
-    def __len__(self) -> int:
-        if isinstance(self.issues, dict):
-            return len(self.issues)
-        else:
-            return sum(len(d) for d in self.issues)
-
-    def to_table(self, labelstats: LabelStatsOutput) -> str:
-        if isinstance(self.issues, dict):
-            metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            listed_table = _create_table(metrics, classwise)
-            table = "\n".join(listed_table)
-        else:
-            outertable = []
-            for d in self.issues:
-                metrics, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                listed_table = _create_table(metrics, classwise)
-                str_table = "\n".join(listed_table)
-                outertable.append(str_table)
-            table = "\n\n".join(outertable)
-        return table
-
-    def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
-        import pandas as pd
-
-        if isinstance(self.issues, dict):
-            _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            data = _create_pandas_dataframe(classwise)
-            df = pd.DataFrame(data)
-        else:
-            df_list = []
-            for i, d in enumerate(self.issues):
-                _, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                data = _create_pandas_dataframe(classwise)
-                single_df = pd.DataFrame(data)
-                single_df["Dataset"] = i
-                df_list.append(single_df)
-            df = pd.concat(df_list)
-        return df
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._imagestats import imagestats
+from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
+from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images
 
 
 def _get_outlier_mask(
@@ -226,7 +104,7 @@ class Outliers:
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
     ):
-        self.stats: DatasetStatsOutput
+        self.stats: ImageStatsOutput
        self.use_dimension = use_dimension
         self.use_pixel = use_pixel
         self.use_visual = use_visual
@@ -247,23 +125,23 @@ class Outliers:
         return dict(sorted(flagged_images.items()))
 
     @overload
-    def from_stats(self, stats: OutlierStatsOutput | DatasetStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
+    def from_stats(self, stats: OutlierStatsOutput | ImageStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
 
     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
 
     @set_metadata(state=["outlier_method", "outlier_threshold"])
     def from_stats(
-        self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        self, stats: OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
     ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
         """
         Returns indices of Outliers with the issues identified for each.
 
         Parameters
         ----------
-        stats : OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        stats : OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
             The output(s) from a dimensionstats, pixelstats, or visualstats metric
-            analysis or an aggregate DatasetStatsOutput
+            analysis or an aggregate ImageStatsOutput
 
         Returns
         -------
@@ -290,11 +168,7 @@ class Outliers:
         >>> results.issues[1]
         {}
         """  # noqa: E501
-        if isinstance(stats, DatasetStatsOutput):
-            outliers = self._get_outliers({k: v for o in stats._outputs() for k, v in o.dict().items()})
-            return OutliersOutput(outliers)
-
-        if isinstance(stats, (DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+        if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
             return OutliersOutput(self._get_outliers(stats.dict()))
 
         if not isinstance(stats, Sequence):
@@ -305,7 +179,7 @@ class Outliers:
         stats_map: dict[type, list[int]] = {}
         for i, stats_output in enumerate(stats):
             if not isinstance(
-                stats_output, (DatasetStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
             ):
                 raise TypeError(
                     "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -323,14 +197,14 @@ class Outliers:
         return OutliersOutput(output_list)
 
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (C, H, W)
-            A dataset of images in an ArrayLike format
+        data : Iterable[Array], shape - (C, H, W)
+            A dataset of images in an Array format
 
         Returns
         -------
@@ -347,8 +221,9 @@ class Outliers:
         >>> list(results.issues)
         [10, 12]
         >>> results.issues[10]
-        {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
+        {'contrast': 1.25, 'zeros': 0.05493, 'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}
         """
-        self.stats = datasetstats(images=data)
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = imagestats(images)
         outliers = self._get_outliers(self.stats.dict())
         return OutliersOutput(outliers)
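As with Duplicates.evaluate, the convenience path now wraps a Dataset in the internal Images view and computes aggregate stats via imagestats. A minimal sketch, using the constructor flags shown in the __init__ hunk above:

    outliers = Outliers(use_dimension=True, use_pixel=True, use_visual=True)
    results = outliers.evaluate(dataset)  # Dataset[Array] | Dataset[tuple[Array, Any, Any]]
    print(results.issues)                 # {index: {issue_name: value, ...}, ...}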
dataeval/detectors/ood/__init__.py

@@ -5,4 +5,4 @@ Out-of-distribution (OOD) detectors identify data that is different from the dat
 __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
 
 from dataeval.detectors.ood.ae import OOD_AE
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
+from dataeval.outputs._ood import OODOutput, OODScoreOutput
dataeval/detectors/ood/ae.py

@@ -16,12 +16,12 @@ from typing import Callable
 
 import numpy as np
 import torch
-from numpy.typing import ArrayLike
+from numpy.typing import NDArray
 
 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
-from dataeval.interop import as_numpy
-from dataeval.utils.torch.internal import predict_batch
+from dataeval.outputs import OODScoreOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils.torch._internal import predict_batch
 
 
 class OOD_AE(OODBase):
@@ -30,8 +30,31 @@ class OOD_AE(OODBase):
 
     Parameters
     ----------
-    model : Autoencoder
-        An Autoencoder model.
+    model : torch.nn.Module
+        An autoencoder model to use for encoding and reconstruction of images
+        for detection of out-of-distribution samples.
+    device : str or torch.Device or None, default None
+        The device to use for the detector. None will default to the global
+        configuration selection if set, otherwise "cuda" then "cpu" by availability.
+
+    Example
+    -------
+    Perform out-of-distribution detection on test data.
+
+    >>> from dataeval.utils.torch.models import AE
+
+    >>> input_shape = train_images[0].shape
+    >>> ood = OOD_AE(AE(input_shape))
+
+    Train the autoencoder using the training data.
+
+    >>> ood.fit(train_images, threshold_perc=99, epochs=20)
+
+    Test for out-of-distribution samples on the test data.
+
+    >>> output = ood.predict(test_images)
+    >>> output.is_ood
+    array([ True,  True, False,  True,  True,  True,  True,  True])
     """
 
     def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
@@ -55,9 +78,7 @@ class OOD_AE(OODBase):
 
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
-        self._validate(X := as_numpy(X))
-
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
         # reconstruct instances
         X_recon = predict_batch(X, self.model, batch_size=batch_size)
 
dataeval/detectors/ood/base.py

@@ -13,12 +13,13 @@ __all__ = []
 from typing import Callable, cast
 
 import torch
-from numpy.typing import ArrayLike
 
+from dataeval.config import get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
-from dataeval.interop import to_numpy
-from dataeval.utils.torch.gmm import GaussianMixtureModelParams, gmm_params
-from dataeval.utils.torch.internal import get_device, trainer
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import to_numpy
+from dataeval.utils.torch._gmm import GaussianMixtureModelParams, gmm_params
+from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
dataeval/detectors/ood/mixin.py

@@ -1,17 +1,17 @@
 from __future__ import annotations
 
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
-
 __all__ = []
 
 from abc import ABC, abstractmethod
 from typing import Callable, Generic, Literal, TypeVar
 
 import numpy as np
-from numpy.typing import ArrayLike, NDArray
+from numpy.typing import NDArray
 
-from dataeval.interop import to_numpy
-from dataeval.output import set_metadata
+from dataeval.outputs import OODOutput, OODScoreOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy, to_numpy
 
 TGMMParams = TypeVar("TGMMParams")
 
@@ -73,6 +73,9 @@ class OODBaseMixin(Generic[TModel], ABC):
     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
             raise TypeError("Dataset should of type: `NDArray`.")
+        if np.min(X) < 0 or np.max(X) > 1:
+            raise ValueError("Embeddings must be on the unit interval [0-1].")
+
         return X.shape[1:], X.dtype.type
 
     def _validate(self, X: NDArray) -> None:
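This new check means anything passed into score() or predict() must already be scaled onto [0, 1], or a ValueError is raised. A hedged pre-scaling sketch, assuming uint8 image data and an already-constructed detector named ood:

    import numpy as np

    images = images.astype(np.float32) / 255.0  # map uint8 pixel values onto [0, 1]
    assert images.min() >= 0.0 and images.max() <= 1.0
    output = ood.predict(images)                # passes the new range validation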
@@ -90,7 +93,7 @@ class OODBaseMixin(Generic[TModel], ABC):
         self._validate(X)
 
     @abstractmethod
-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput: ...
 
     @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
@@ -105,11 +108,17 @@ class OODBaseMixin(Generic[TModel], ABC):
             Number of instances to process in each batch.
             Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         OODScoreOutput
             An object containing the instance-level and feature-level OOD scores.
         """
+        self._validate(X := as_numpy(X).astype(np.float32))
         return self._score(X, batch_size)
 
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
@@ -134,12 +143,17 @@ class OODBaseMixin(Generic[TModel], ABC):
         ood_type : "feature" | "instance", default "instance"
             Predict out-of-distribution at the 'feature' or 'instance' level.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         Dictionary containing the outlier predictions for the selected level,
         and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
         """
-        self._validate_state(X := to_numpy(X))
+        self._validate_state(X := to_numpy(X).astype(np.float32))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
@@ -0,0 +1,73 @@
1
+ """
2
+ Adapted for Pytorch from
3
+
4
+ Source code derived from Alibi-Detect 0.11.4
5
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
6
+
7
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
8
+ Licensed under Apache Software License (Apache 2.0)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ __all__ = []
14
+
15
+ from typing import Callable
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from dataeval.detectors.ood.base import OODBase
21
+ from dataeval.outputs import OODScoreOutput
22
+ from dataeval.typing import ArrayLike
23
+ from dataeval.utils._array import as_numpy
24
+ from dataeval.utils.torch._internal import predict_batch
25
+
26
+
27
+ class OOD_VAE(OODBase):
28
+ """
29
+ Autoencoder based out-of-distribution detector.
30
+
31
+ Parameters
32
+ ----------
33
+ model : Autoencoder
34
+ An Autoencoder model.
35
+ """
36
+
37
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
38
+ super().__init__(model, device)
39
+
40
+ def fit(
41
+ self,
42
+ x_ref: ArrayLike,
43
+ threshold_perc: float,
44
+ loss_fn: Callable[..., torch.nn.Module] | None = None,
45
+ optimizer: torch.optim.Optimizer | None = None,
46
+ epochs: int = 20,
47
+ batch_size: int = 64,
48
+ verbose: bool = False,
49
+ ) -> None:
50
+ if loss_fn is None:
51
+ loss_fn = torch.nn.MSELoss()
52
+
53
+ if optimizer is None:
54
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
55
+
56
+ super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
57
+
58
+ def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
59
+ self._validate(X := as_numpy(X))
60
+
61
+ # reconstruct instances
62
+ X_recon = predict_batch(X, self.model, batch_size=batch_size)[0] # don't need mu or logvar from model
63
+
64
+ # compute feature and instance level scores
65
+ fscore = np.power(X.reshape((len(X), -1)) - X_recon, 2)
66
+ # fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
67
+ # n_score_features = int(np.ceil(fscore_flat.shape[1]))
68
+ # sorted_fscore = np.sort(fscore_flat, axis=1)
69
+ # sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
70
+ # iscore = np.mean(sorted_fscore_perc, axis=1)
71
+ iscore = np.sum(fscore, axis=1)
72
+
73
+ return OODScoreOutput(iscore, fscore)
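Usage presumably mirrors the OOD_AE docstring example above, except that the model's forward pass is expected to return a tuple whose first element is the reconstruction (the mu/logvar outputs are discarded in _score). A sketch under that assumption; the VAE import path below is hypothetical and not confirmed by this diff:

    from dataeval.utils.torch.models import VAE  # hypothetical import path

    ood = OOD_VAE(VAE(input_shape))              # model returns (recon, mu, logvar)
    ood.fit(train_images, threshold_perc=99, epochs=20)
    print(ood.predict(test_images).is_ood)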
dataeval/metadata/__init__.py

@@ -0,0 +1,6 @@
+"""Explanatory functions using metadata and additional features such as ood or drift"""
+
+__all__ = ["most_deviated_factors", "metadata_distance"]
+
+from dataeval.metadata._distance import metadata_distance
+from dataeval.metadata._ood import most_deviated_factors
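Only the exported names and the module docstring are confirmed by this hunk; the signatures live in the new _distance.py and _ood.py modules listed above. A minimal import sketch:

    from dataeval.metadata import metadata_distance, most_deviated_factors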