dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0

dataeval/detectors/linters/outliers.py
@@ -2,142 +2,19 @@ from __future__ import annotations
 
 __all__ = []
 
-import contextlib
-from dataclasses import dataclass
-from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Literal, Sequence, overload
 
 import numpy as np
 from numpy.typing import NDArray
-from torch.utils.data import Dataset
-
-from dataeval._output import Output, set_metadata
-from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats._datasetstats import DatasetStatsOutput, datasetstats
-from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
-from dataeval.metrics.stats._labelstats import LabelStatsOutput
-from dataeval.metrics.stats._pixelstats import PixelStatsOutput
-from dataeval.metrics.stats._visualstats import VisualStatsOutput
-from dataeval.typing import ArrayLike
-
-with contextlib.suppress(ImportError):
-    import pandas as pd
-
-
-IndexIssueMap = dict[int, dict[str, float]]
-OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
-TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
-
-
-def _reorganize_by_class_and_metric(result, lstats):
-    """Flip result from grouping by image to grouping by class and metric"""
-    metrics = {}
-    class_wise = {label: {} for label in lstats.image_indices_per_label}
-
-    # Group metrics and calculate class-wise counts
-    for img, group in result.items():
-        for extreme in group:
-            metrics.setdefault(extreme, []).append(img)
-            for label, images in lstats.image_indices_per_label.items():
-                if img in images:
-                    class_wise[label][extreme] = class_wise[label].get(extreme, 0) + 1
-
-    return metrics, class_wise
-
-
-def _create_table(metrics, class_wise):
-    """Create table for displaying the results"""
-    max_class_length = max(len(str(label)) for label in class_wise) + 2
-    max_total = max(len(metrics[group]) for group in metrics) + 2
-
-    table_header = " | ".join(
-        [f"{'Class':>{max_class_length}}"]
-        + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
-        + [f"{'Total':<{max_total}}"]
-    )
-    table_rows = []
-
-    for class_cat, results in class_wise.items():
-        table_value = [f"{class_cat:>{max_class_length}}"]
-        total = 0
-        for group in sorted(metrics.keys()):
-            count = results.get(group, 0)
-            table_value.append(f"{count:^{max(5, len(str(group))) + 2}}")
-            total += count
-        table_value.append(f"{total:^{max_total}}")
-        table_rows.append(" | ".join(table_value))
-
-    table = [table_header] + table_rows
-    return table
-
-
-def _create_pandas_dataframe(class_wise):
-    """Create data for pandas dataframe"""
-    data = []
-    for label, metrics_dict in class_wise.items():
-        row = {"Class": label}
-        total = sum(metrics_dict.values())
-        row.update(metrics_dict)  # Add metric counts
-        row["Total"] = total
-        data.append(row)
-    return data
-
-
-@dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
-    """
-    Output class for :class:`.Outliers` lint detector.
-
-    Attributes
-    ----------
-    issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
-        Indices of image Outliers with their associated issue type and calculated values.
-
-        - For a single dataset, a dictionary containing the indices of outliers and
-          a dictionary showing the issues and calculated values for the given index.
-        - For multiple stats outputs, a list of dictionaries containing the indices of
-          outliers and their associated issues and calculated values.
-    """
 
-    issues: TIndexIssueMap
-
-    def __len__(self) -> int:
-        if isinstance(self.issues, dict):
-            return len(self.issues)
-        else:
-            return sum(len(d) for d in self.issues)
-
-    def to_table(self, labelstats: LabelStatsOutput) -> str:
-        if isinstance(self.issues, dict):
-            metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            listed_table = _create_table(metrics, classwise)
-            table = "\n".join(listed_table)
-        else:
-            outertable = []
-            for d in self.issues:
-                metrics, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                listed_table = _create_table(metrics, classwise)
-                str_table = "\n".join(listed_table)
-                outertable.append(str_table)
-            table = "\n\n".join(outertable)
-        return table
-
-    def to_dataframe(self, labelstats: LabelStatsOutput) -> pd.DataFrame:
-        import pandas as pd
-
-        if isinstance(self.issues, dict):
-            _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
-            data = _create_pandas_dataframe(classwise)
-            df = pd.DataFrame(data)
-        else:
-            df_list = []
-            for i, d in enumerate(self.issues):
-                _, classwise = _reorganize_by_class_and_metric(d, labelstats)
-                data = _create_pandas_dataframe(classwise)
-                single_df = pd.DataFrame(data)
-                single_df["Dataset"] = i
-                df_list.append(single_df)
-            df = pd.concat(df_list)
-        return df
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._imagestats import imagestats
+from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
+from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.typing import Array, Dataset
+from dataeval.utils.data._images import Images
 
 
 def _get_outlier_mask(
@@ -227,7 +104,7 @@ class Outliers:
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
     ):
-        self.stats: DatasetStatsOutput
+        self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
         self.use_visual = use_visual
@@ -248,23 +125,23 @@ class Outliers:
         return dict(sorted(flagged_images.items()))
 
     @overload
-    def from_stats(self, stats: OutlierStatsOutput | DatasetStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
+    def from_stats(self, stats: OutlierStatsOutput | ImageStatsOutput) -> OutliersOutput[IndexIssueMap]: ...
 
     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
 
     @set_metadata(state=["outlier_method", "outlier_threshold"])
    def from_stats(
-        self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        self, stats: OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
     ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
         """
         Returns indices of Outliers with the issues identified for each.
 
         Parameters
         ----------
-        stats : OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
+        stats : OutlierStatsOutput | ImageStatsOutput | Sequence[OutlierStatsOutput]
             The output(s) from a dimensionstats, pixelstats, or visualstats metric
-            analysis or an aggregate DatasetStatsOutput
+            analysis or an aggregate ImageStatsOutput
 
         Returns
         -------
@@ -291,12 +168,8 @@ class Outliers:
         >>> results.issues[1]
         {}
         """  # noqa: E501
-        if isinstance(stats, DatasetStatsOutput):
-            outliers = self._get_outliers({k: v for o in stats._outputs() for k, v in o.dict().items()})
-            return OutliersOutput(outliers)
-
-        if isinstance(stats, (DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
-            return OutliersOutput(self._get_outliers(stats.dict()))
+        if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+            return OutliersOutput(self._get_outliers(stats.data()))
 
         if not isinstance(stats, Sequence):
             raise TypeError(
@@ -306,7 +179,7 @@ class Outliers:
         stats_map: dict[type, list[int]] = {}
         for i, stats_output in enumerate(stats):
             if not isinstance(
-                stats_output, (DatasetStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
             ):
                 raise TypeError(
                     "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -316,29 +189,22 @@ class Outliers:
         output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
         for _, indices in stats_map.items():
             substats, dataset_steps = combine_stats([stats[i] for i in indices])
-            outliers = self._get_outliers(substats.dict())
+            outliers = self._get_outliers(substats.data())
             for idx, issue in outliers.items():
                 k, v = get_dataset_step_from_idx(idx, dataset_steps)
                 output_list[indices[k]][v] = issue
 
         return OutliersOutput(output_list)
 
-    @overload
-    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]: ...
-    @overload
-    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> OutliersOutput[IndexIssueMap]: ...
-
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(
-        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
-    ) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
         Parameters
         ----------
-        data : Iterable[ArrayLike], shape - (C, H, W)
-            A dataset of images in an ArrayLike format
+        data : Iterable[Array], shape - (C, H, W)
+            A dataset of images in an Array format
 
         Returns
         -------
@@ -355,9 +221,9 @@ class Outliers:
         >>> list(results.issues)
         [10, 12]
         >>> results.issues[10]
-        {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
+        {'contrast': 1.25, 'zeros': 0.05493, 'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128}
         """
-        images = (d[0] for d in data) if isinstance(data, Dataset) else data
-        self.stats = datasetstats(images=images)
-        outliers = self._get_outliers(self.stats.dict())
+        images = Images(data) if isinstance(data, Dataset) else data
+        self.stats = imagestats(images)
+        outliers = self._get_outliers(self.stats.data())
         return OutliersOutput(outliers)
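
Taken together, the Outliers changes are a three-part API migration: evaluate() now takes a typed Dataset rather than a bare iterable, the aggregate stats call moves from the removed datasetstats to the new imagestats, and stats outputs are read with .data() instead of .dict(). A minimal sketch of the resulting call pattern, assuming imagestats is re-exported from dataeval.metrics.stats like the other stats functions, with `dataset` standing in for any image dataset:

    from dataeval.detectors.linters import Outliers
    from dataeval.metrics.stats import imagestats

    detector = Outliers(outlier_method="modzscore")

    # evaluate() now computes its stats internally via imagestats()
    results = detector.evaluate(dataset)
    print(results.issues)

    # or reuse a precomputed aggregate; ImageStatsOutput replaces DatasetStatsOutput
    stats = imagestats(dataset)
    results = detector.from_stats(stats)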

dataeval/detectors/ood/__init__.py
@@ -5,4 +5,4 @@ Out-of-distribution (OOD) detectors identify data that is different from the dat
 __all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
 
 from dataeval.detectors.ood.ae import OOD_AE
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
+from dataeval.outputs._ood import OODOutput, OODScoreOutput

dataeval/detectors/ood/ae.py
@@ -18,8 +18,9 @@ import numpy as np
 import torch
 from numpy.typing import NDArray
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
 from dataeval.utils.torch._internal import predict_batch
 
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
     model : torch.nn.Module
         An autoencoder model to use for encoding and reconstruction of images
        for detection of out-of-distribution samples.
-    device : str or torch.Device or None, default None
-        The device to use for the detector. None will default to the global
-        configuration selection if set, otherwise "cuda" then "cpu" by availability.
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
 
     Example
     -------
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
     array([ True,  True, False,  True,  True,  True,  True,  True])
     """
 
-    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(
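
The device signature change in OOD_AE recurs in OODBase and OOD_VAE below: the str | torch.device union becomes the DeviceLike alias from dataeval.config, which get_device() resolves to a concrete torch.device. A hedged sketch of the accepted forms, with encoder_net standing in for any autoencoder torch.nn.Module:

    import torch
    from dataeval.detectors.ood import OOD_AE

    detector = OOD_AE(encoder_net, device="cpu")                 # device string
    detector = OOD_AE(encoder_net, device=torch.device("cuda"))  # torch.device
    detector = OOD_AE(encoder_net, device=None)                  # DataEval/torch default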

dataeval/detectors/ood/base.py
@@ -14,7 +14,7 @@ from typing import Callable, cast
 
 import torch
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import to_numpy
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
-    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         self.device: torch.device = get_device(device)
         super().__init__(model)
 

dataeval/detectors/ood/metadata_ood_mi.py
@@ -10,6 +10,8 @@ import numpy as np
 from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif
 
+from dataeval.config import get_seed
+
 # NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
 # which is what many library functions return, multiply it by NATS2BITS to get it in bits.
 NATS2BITS = 1.442695
@@ -19,7 +21,6 @@ def get_metadata_ood_mi(
     metadata: dict[str, list[Any] | NDArray[Any]],
     is_ood: NDArray[np.bool_],
     discrete_features: str | bool | NDArray[np.bool_] = False,
-    random_state: int | None = None,
 ) -> dict[str, float]:
     """Computes mutual information between a set of metadata features and an out-of-distribution flag.
 
@@ -39,9 +40,6 @@
         A boolean array, with one value per example, that indicates which examples are OOD.
     discrete_features : str | bool | NDArray[np.bool_]
         Either a boolean array or a single boolean value, indicate which features take on discrete values.
-    random_state : int, optional - default None
-        Determines random number generation for small noise added to continuous variables. Set to a value for
-        reproducible results.
 
     Returns
     -------
@@ -55,7 +53,7 @@
 
     >>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
     >>> is_ood = metadata["altitude"] > 100
-    >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False, random_state=0)
+    >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False)
     {'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
     """
     numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
@@ -84,7 +82,7 @@
             Xscl,
             is_ood,
             discrete_features=discrete_features,  # type: ignore
-            random_state=random_state,
+            random_state=get_seed(),
         )
         * NATS2BITS
     )
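
get_metadata_ood_mi() trades its per-call random_state argument for the library-wide seed read by get_seed(), so reproducibility is now configured once, globally. A sketch of the new pattern, assuming a set_seed setter was added to dataeval.config alongside get_seed in this release:

    import numpy as np
    from dataeval.config import set_seed  # assumed counterpart to get_seed
    from dataeval.detectors.ood.metadata_ood_mi import get_metadata_ood_mi

    set_seed(0)  # replaces get_metadata_ood_mi(..., random_state=0)

    metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
    is_ood = metadata["altitude"] > 100
    print(get_metadata_ood_mi(metadata, is_ood, discrete_features=False))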

dataeval/detectors/ood/mixin.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from dataeval.detectors.ood.output import OODOutput, OODScoreOutput
-
 __all__ = []
 
 from abc import ABC, abstractmethod
@@ -10,7 +8,8 @@ from typing import Callable, Generic, Literal, TypeVar
 import numpy as np
 from numpy.typing import NDArray
 
-from dataeval._output import set_metadata
+from dataeval.outputs import OODOutput, OODScoreOutput
+from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import as_numpy, to_numpy
 
@@ -158,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
-        return OODOutput(is_ood=ood_pred, **score.dict())
+        return OODOutput(is_ood=ood_pred, **score.data())

dataeval/detectors/ood/vae.py
@@ -17,8 +17,9 @@ from typing import Callable
 import numpy as np
 import torch
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
-from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import as_numpy
 from dataeval.utils.torch._internal import predict_batch
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
         An Autoencoder model.
     """
 
-    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(

dataeval/metadata/__init__.py
@@ -1,5 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""
 
-__all__ = ["most_deviated_factors"]
+__all__ = ["most_deviated_factors", "metadata_distance"]
 
+from dataeval.metadata._distance import metadata_distance
 from dataeval.metadata._ood import most_deviated_factors

dataeval/metadata/_distance.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from typing import NamedTuple, cast
+
+import numpy as np
+from scipy.stats import iqr, ks_2samp
+from scipy.stats import wasserstein_distance as emd
+
+from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
+from dataeval.outputs._base import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils.data import Metadata
+
+
+class KSType(NamedTuple):
+    """Used to typehint scipy's internal hidden ks_2samp output"""
+
+    statistic: float
+    statistic_location: float
+    pvalue: float
+
+
+def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
+    """Calculates the shift magnitude between x1 and x2 scaled by x1"""
+
+    distance = emd(x1, x2)
+
+    X = iqr(x1)
+
+    # Preferred scaling of x1
+    if X:
+        return distance / X
+
+    # Return if single-valued, else scale
+    xmin, xmax = np.min(x1), np.max(x1)
+    return distance if xmin == xmax else distance / (xmax - xmin)
+
+
+@set_metadata
+def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
+    """
+    Measures the feature-wise distance between two continuous metadata distributions and
+    computes a p-value to evaluate its significance.
+
+    Uses the Earth Mover's Distance and the Kolmogorov-Smirnov two-sample test, featurewise.
+
+    Parameters
+    ----------
+    metadata1 : Metadata
+        Class containing continuous factor names and values to be used as reference
+    metadata2 : Metadata
+        Class containing continuous factor names and values to be compare with the reference
+
+    Returns
+    -------
+    MetadataDistanceOutput
+        A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
+        defined by scipy.stats.ks_2samp.
+
+    See Also
+    --------
+    Earth mover's distance
+
+    Kolmogorov-Smirnov two-sample test
+
+    Note
+    ----
+    This function only applies to the continuous data
+
+    Examples
+    --------
+    >>> output = metadata_distance(metadata1, metadata2)
+    >>> list(output)
+    ['time', 'altitude']
+    >>> output["time"]
+    MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
+    """
+
+    _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
+    fnames = metadata1.continuous_factor_names
+
+    cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
+    cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+
+    _validate_factors_and_data(fnames, cont1)
+    _validate_factors_and_data(fnames, cont2)
+
+    N = len(cont1)
+    M = len(cont2)
+
+    # This is a simplified version of sqrt(N*M / N+M) < 4
+    if (N - 16) * (M - 16) < 256:
+        warnings.warn(
+            f"Sample sizes of {N}, {M} will yield unreliable p-values from the KS test. "
+            f"Recommended 32 samples per factor or at least 16 if one set has many more.",
+            UserWarning,
+        )
+
+    # Set default for statistic, location, and magnitude to zero and pvalue to one
+    results: dict[str, MetadataDistanceValues] = {}
+
+    # Per factor
+    for i, fname in enumerate(fnames):
+        fdata1 = cont1[:, i]  # (S, 1)
+        fdata2 = cont2[:, i]  # (S, 1)
+
+        # Min and max over both distributions
+        xmin = min(np.min(fdata1), np.min(fdata2))
+        xmax = max(np.max(fdata1), np.max(fdata2))
+
+        # Default case
+        if xmin == xmax:
+            results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
+            continue
+
+        ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
+
+        # Normalized location
+        loc = float((ks_result.statistic_location - xmin) / (xmax - xmin))
+
+        drift = _calculate_drift(fdata1, fdata2)
+
+        results[fname] = MetadataDistanceValues(
+            statistic=ks_result.statistic,
+            location=loc,
+            dist=drift,
+            pvalue=ks_result.pvalue,
+        )
+
+    return MetadataDistanceOutput(results)
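
One detail of the new metadata_distance worth unpacking is its sample-size guard. The inline comment calls (N - 16) * (M - 16) < 256 a simplified version of sqrt(N*M / (N + M)) < 4, and the algebra confirms the two are equivalent:

    sqrt(N*M / (N + M)) < 4
    <=>  N*M < 16*(N + M)
    <=>  N*M - 16*N - 16*M + 256 < 256
    <=>  (N - 16)*(M - 16) < 256

That is, the warning fires exactly when the effective two-sample size N*M/(N + M), which governs the reliability of the asymptotic KS p-value, falls below 16. Two sets of 32 samples sit right on the boundary, matching the recommendation in the warning text.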