dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
--- a/dataeval/detectors/linters/duplicates.py
+++ b/dataeval/detectors/linters/duplicates.py
@@ -3,13 +3,14 @@ from __future__ import annotations
 __all__ = []
 
 from dataclasses import dataclass
-from typing import Generic, Iterable, Sequence, TypeVar, overload
+from typing import Any, Generic, Iterable, Sequence, TypeVar, overload
 
-from numpy.typing import ArrayLike
+from torch.utils.data import Dataset
 
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
-from dataeval.output import Output, set_metadata
+from dataeval._output import Output, set_metadata
+from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._hashstats import HashStatsOutput, hashstats
+from dataeval.typing import ArrayLike
 
 DuplicateGroup = list[int]
 DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -19,7 +20,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
 @dataclass(frozen=True)
 class DuplicatesOutput(Generic[TIndexCollection], Output):
     """
-    Output class for :class:`Duplicates` lint detector.
+    Output class for :class:`.Duplicates` lint detector.
 
     Attributes
     ----------
@@ -133,8 +134,15 @@ class Duplicates:
 
         return DuplicatesOutput(**duplicates)
 
+    @overload
+    def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]: ...
+    @overload
+    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> DuplicatesOutput[DuplicateGroup]: ...
+
     @set_metadata(state=["only_exact"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
+    def evaluate(
+        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
+    ) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches
 
@@ -158,6 +166,7 @@ class Duplicates:
         >>> all_dupes.evaluate(duplicate_images)
         DuplicatesOutput(exact=[[3, 20], [16, 37]], near=[[3, 20, 22], [12, 18], [13, 36], [14, 31], [17, 27], [19, 38, 47]])
         """  # noqa: E501
-        self.stats = hashstats(data)
+        images = (d[0] for d in data) if isinstance(data, Dataset) else data
+        self.stats = hashstats(images)
         duplicates = self._get_duplicates(self.stats.dict())
         return DuplicatesOutput(**duplicates)
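The practical effect of this change: `Duplicates.evaluate` now accepts either a plain iterable of arrays or a torch `Dataset` yielding `(image, target, metadata)` tuples, unpacking `d[0]` from each item before hashing. A minimal sketch of both call styles; the `TupleDataset` wrapper is a hypothetical illustration (not a dataeval class), and the import path and default construction of `Duplicates` are assumed:

import numpy as np
from torch.utils.data import Dataset

from dataeval.detectors.linters import Duplicates  # import path assumed


class TupleDataset(Dataset):
    """Hypothetical wrapper yielding the (image, target, metadata) tuples evaluate expects."""

    def __init__(self, images: list[np.ndarray]) -> None:
        self.images = images

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, int, dict]:
        return self.images[idx], 0, {}


rng = np.random.default_rng(0)
images = [rng.random((3, 32, 32)) for _ in range(8)]
images[5] = images[2]  # plant an exact duplicate

dupes = Duplicates()
print(dupes.evaluate(images).exact)                # iterable of arrays, as before
print(dupes.evaluate(TupleDataset(images)).exact)  # new: Dataset of 3-tuples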
--- a/dataeval/detectors/linters/outliers.py
+++ b/dataeval/detectors/linters/outliers.py
@@ -4,19 +4,20 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar, Union, overload
 
 import numpy as np
-from numpy.typing import ArrayLike, NDArray
-
-from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
-from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX
-from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
-from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
-from dataeval.metrics.stats.labelstats import LabelStatsOutput
-from dataeval.metrics.stats.pixelstats import PixelStatsOutput
-from dataeval.metrics.stats.visualstats import VisualStatsOutput
-from dataeval.output import Output, set_metadata
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from dataeval._output import Output, set_metadata
+from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, combine_stats, get_dataset_step_from_idx
+from dataeval.metrics.stats._datasetstats import DatasetStatsOutput, datasetstats
+from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
+from dataeval.metrics.stats._labelstats import LabelStatsOutput
+from dataeval.metrics.stats._pixelstats import PixelStatsOutput
+from dataeval.metrics.stats._visualstats import VisualStatsOutput
+from dataeval.typing import ArrayLike
 
 with contextlib.suppress(ImportError):
     import pandas as pd
@@ -84,7 +85,7 @@ def _create_pandas_dataframe(class_wise):
 @dataclass(frozen=True)
 class OutliersOutput(Generic[TIndexIssueMap], Output):
     """
-    Output class for :class:`Outliers` lint detector.
+    Output class for :class:`.Outliers` lint detector.
 
     Attributes
     ----------
@@ -322,8 +323,15 @@ class Outliers:
 
         return OutliersOutput(output_list)
 
+    @overload
+    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]: ...
+    @overload
+    def evaluate(self, data: Dataset[tuple[ArrayLike, Any, dict[str, Any]]]) -> OutliersOutput[IndexIssueMap]: ...
+
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(
+        self, data: Iterable[ArrayLike] | Dataset[tuple[ArrayLike, Any, dict[str, Any]]]
+    ) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
@@ -349,6 +357,7 @@ class Outliers:
         >>> results.issues[10]
         {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
         """
-        self.stats = datasetstats(images=data)
+        images = (d[0] for d in data) if isinstance(data, Dataset) else data
+        self.stats = datasetstats(images=images)
         outliers = self._get_outliers(self.stats.dict())
         return OutliersOutput(outliers)
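`Outliers.evaluate` gains the identical overloads, so the same dataset object can feed both linters. A short sketch reusing the hypothetical `TupleDataset` and `images` from the `Duplicates` example above; the import path, default `Outliers()` construction, and dict-style `.issues` access are assumptions based on the doctest in this diff:

from dataeval.detectors.linters import Outliers  # import path assumed

outliers = Outliers()
results = outliers.evaluate(TupleDataset(images))

# OutliersOutput.issues maps image index -> {stat_name: value}, per the doctest above
for idx, issues in results.issues.items():
    print(idx, issues)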
--- a/dataeval/detectors/ood/ae.py
+++ b/dataeval/detectors/ood/ae.py
@@ -16,12 +16,12 @@ from typing import Callable
 
 import numpy as np
 import torch
-from numpy.typing import ArrayLike
+from numpy.typing import NDArray
 
 from dataeval.detectors.ood.base import OODBase
 from dataeval.detectors.ood.output import OODScoreOutput
-from dataeval.interop import as_numpy
-from dataeval.utils.torch.internal import predict_batch
+from dataeval.typing import ArrayLike
+from dataeval.utils.torch._internal import predict_batch
 
 
 class OOD_AE(OODBase):
@@ -30,8 +30,31 @@ class OOD_AE(OODBase):
 
     Parameters
     ----------
-    model : Autoencoder
-        An Autoencoder model.
+    model : torch.nn.Module
+        An autoencoder model to use for encoding and reconstruction of images
+        for detection of out-of-distribution samples.
+    device : str or torch.Device or None, default None
+        The device to use for the detector. None will default to the global
+        configuration selection if set, otherwise "cuda" then "cpu" by availability.
+
+    Example
+    -------
+    Perform out-of-distribution detection on test data.
+
+    >>> from dataeval.utils.torch.models import AE
+
+    >>> input_shape = train_images[0].shape
+    >>> ood = OOD_AE(AE(input_shape))
+
+    Train the autoencoder using the training data.
+
+    >>> ood.fit(train_images, threshold_perc=99, epochs=20)
+
+    Test for out-of-distribution samples on the test data.
+
+    >>> output = ood.predict(test_images)
+    >>> output.is_ood
+    array([ True,  True, False,  True,  True,  True,  True,  True])
     """
 
     def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
@@ -55,9 +78,7 @@ class OOD_AE(OODBase):
 
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
-        self._validate(X := as_numpy(X))
-
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
        # reconstruct instances
        X_recon = predict_batch(X, self.model, batch_size=batch_size)
 
--- a/dataeval/detectors/ood/base.py
+++ b/dataeval/detectors/ood/base.py
@@ -13,12 +13,13 @@ __all__ = []
 from typing import Callable, cast
 
 import torch
-from numpy.typing import ArrayLike
 
+from dataeval.config import get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
-from dataeval.interop import to_numpy
-from dataeval.utils.torch.gmm import GaussianMixtureModelParams, gmm_params
-from dataeval.utils.torch.internal import get_device, trainer
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import to_numpy
+from dataeval.utils.torch._gmm import GaussianMixtureModelParams, gmm_params
+from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
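For subclass authors, the `_score` signature change (here and in ae.py above) is the key point: conversion and validation moved up into the mixin's `score`, so `_score` now receives an already validated `NDArray[np.float32]`. A sketch of a custom detector under the new contract; `OOD_L1` is hypothetical, and its L1 scoring merely illustrates the shape of an implementation:

import numpy as np
from numpy.typing import NDArray

from dataeval.detectors.ood.base import OODBase
from dataeval.detectors.ood.output import OODScoreOutput
from dataeval.utils.torch._internal import predict_batch


class OOD_L1(OODBase):
    """Hypothetical detector scoring reconstructions by L1 error instead of L2."""

    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
        # X arrives already converted and validated by OODBaseMixin.score();
        # no as_numpy/_validate call is needed here anymore
        X_recon = predict_batch(X, self.model, batch_size=batch_size)
        fscore = np.abs(X.reshape((len(X), -1)) - X_recon)
        return OODScoreOutput(np.sum(fscore, axis=1), fscore)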
--- a/dataeval/detectors/ood/metadata_ks_compare.py
+++ b/dataeval/detectors/ood/metadata_ks_compare.py
@@ -11,7 +11,7 @@ from numpy.typing import NDArray
 from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd
 
-from dataeval.output import MappingOutput, set_metadata
+from dataeval._output import MappingOutput, set_metadata
 
 
 class MetadataKSResult(NamedTuple):
--- a/dataeval/detectors/ood/mixin.py
+++ b/dataeval/detectors/ood/mixin.py
@@ -8,10 +8,11 @@ from abc import ABC, abstractmethod
 from typing import Callable, Generic, Literal, TypeVar
 
 import numpy as np
-from numpy.typing import ArrayLike, NDArray
+from numpy.typing import NDArray
 
-from dataeval.interop import to_numpy
-from dataeval.output import set_metadata
+from dataeval._output import set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy, to_numpy
 
 TGMMParams = TypeVar("TGMMParams")
 
@@ -73,6 +74,9 @@ class OODBaseMixin(Generic[TModel], ABC):
     def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
         if not isinstance(X, np.ndarray):
             raise TypeError("Dataset should of type: `NDArray`.")
+        if np.min(X) < 0 or np.max(X) > 1:
+            raise ValueError("Embeddings must be on the unit interval [0-1].")
+
         return X.shape[1:], X.dtype.type
 
     def _validate(self, X: NDArray) -> None:
@@ -90,7 +94,7 @@ class OODBaseMixin(Generic[TModel], ABC):
         self._validate(X)
 
     @abstractmethod
-    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
+    def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput: ...
 
     @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
@@ -105,11 +109,17 @@ class OODBaseMixin(Generic[TModel], ABC):
             Number of instances to process in each batch.
             Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         OODScoreOutput
             An object containing the instance-level and feature-level OOD scores.
         """
+        self._validate(X := as_numpy(X).astype(np.float32))
         return self._score(X, batch_size)
 
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
@@ -134,12 +144,17 @@ class OODBaseMixin(Generic[TModel], ABC):
        ood_type : "feature" | "instance", default "instance"
             Predict out-of-distribution at the 'feature' or 'instance' level.
 
+        Raises
+        ------
+        ValueError
+            X input data must be unit interval [0-1].
+
         Returns
         -------
         Dictionary containing the outlier predictions for the selected level,
         and the OOD scores for the data including both 'instance' and 'feature' (if present) level scores.
         """
-        self._validate_state(X := to_numpy(X))
+        self._validate_state(X := to_numpy(X).astype(np.float32))
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
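The behavioral change to flag for callers: `score` and `predict` now reject inputs outside the unit interval and cast to `float32` up front, so uint8 image batches must be rescaled first. A minimal sketch, with `ood` standing in for any fitted detector such as the `OOD_AE` instance from the docstring example above:

import numpy as np

raw = np.random.default_rng(0).integers(0, 256, (16, 3, 32, 32), dtype=np.uint8)

x = raw.astype(np.float32) / 255.0  # satisfy the new [0-1] requirement

output = ood.predict(x)   # fine: values lie on [0, 1]
# ood.predict(raw)        # raises ValueError: "Embeddings must be on the unit interval [0-1]."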
--- a/dataeval/detectors/ood/output.py
+++ b/dataeval/detectors/ood/output.py
@@ -8,7 +8,7 @@ from typing import Literal
 import numpy as np
 from numpy.typing import NDArray
 
-from dataeval.output import Output
+from dataeval._output import Output
 
 
 @dataclass(frozen=True)
--- /dev/null
+++ b/dataeval/detectors/ood/vae.py
@@ -0,0 +1,73 @@
+"""
+Adapted for Pytorch from
+
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Callable
+
+import numpy as np
+import torch
+
+from dataeval.detectors.ood.base import OODBase
+from dataeval.detectors.ood.output import OODScoreOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch._internal import predict_batch
+
+
+class OOD_VAE(OODBase):
+    """
+    Autoencoder based out-of-distribution detector.
+
+    Parameters
+    ----------
+    model : Autoencoder
+        An Autoencoder model.
+    """
+
+    def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
+        super().__init__(model, device)
+
+    def fit(
+        self,
+        x_ref: ArrayLike,
+        threshold_perc: float,
+        loss_fn: Callable[..., torch.nn.Module] | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        epochs: int = 20,
+        batch_size: int = 64,
+        verbose: bool = False,
+    ) -> None:
+        if loss_fn is None:
+            loss_fn = torch.nn.MSELoss()
+
+        if optimizer is None:
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
+
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+
+    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
+        self._validate(X := as_numpy(X))
+
+        # reconstruct instances
+        X_recon = predict_batch(X, self.model, batch_size=batch_size)[0]  # don't need mu or logvar from model
+
+        # compute feature and instance level scores
+        fscore = np.power(X.reshape((len(X), -1)) - X_recon, 2)
+        # fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
+        # n_score_features = int(np.ceil(fscore_flat.shape[1]))
+        # sorted_fscore = np.sort(fscore_flat, axis=1)
+        # sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
+        # iscore = np.mean(sorted_fscore_perc, axis=1)
+        iscore = np.sum(fscore, axis=1)
+
+        return OODScoreOutput(iscore, fscore)
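Note that `OOD_VAE._score` keeps only `predict_batch(...)[0]`, so the wrapped module's forward pass is expected to return a tuple whose first element is the reconstruction, e.g. `(recon, mu, logvar)`. A hedged usage sketch with a stand-in module: `TinyVAE` below is hypothetical (not a dataeval class), and `train_images`/`test_images` are assumed to be float32 arrays on [0, 1]:

import torch

from dataeval.detectors.ood.vae import OOD_VAE


class TinyVAE(torch.nn.Module):
    """Hypothetical VAE returning (reconstruction, mu, logvar)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.enc_mu = torch.nn.Linear(dim, 8)
        self.enc_logvar = torch.nn.Linear(dim, 8)
        self.dec = torch.nn.Linear(8, dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        flat = x.flatten(1)
        mu, logvar = self.enc_mu(flat), self.enc_logvar(flat)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return self.dec(z), mu, logvar


ood = OOD_VAE(TinyVAE(3 * 32 * 32))
ood.fit(train_images, threshold_perc=99, epochs=20)
print(ood.predict(test_images).is_ood)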
--- /dev/null
+++ b/dataeval/metadata/__init__.py
@@ -0,0 +1,5 @@
+"""Explanatory functions using metadata and additional features such as ood or drift"""
+
+__all__ = ["most_deviated_factors"]
+
+from dataeval.metadata._ood import most_deviated_factors
--- /dev/null
+++ b/dataeval/metadata/_ood.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.detectors.ood import OODOutput
+from dataeval.utils.data import Metadata
+
+
+def _validate_keys(keys1: list[str], keys2: list[str]) -> None:
+    """
+    Raises error when two lists are not equivalent including ordering
+
+    Parameters
+    ----------
+    keys1 : list of strings
+        List of strings to compare
+    keys2 : list of strings
+        List of strings to compare
+
+    Raises
+    ------
+    ValueError
+        If lists do not have the same values, value counts, or ordering
+    """
+
+    if keys1 != keys2:
+        raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
+
+
+def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+    """
+    Raises error when the number of factors and number of rows do not match
+
+    Parameters
+    ----------
+    factors : list of strings
+        List of factor names of size N
+    data : NDArray
+        Array of values with shape (M, N)
+
+    Raises
+    ------
+    ValueError
+        If the length of factors does not equal the length of the transposed data
+    """
+    if len(factors) != len(data.T):
+        raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
+
+
+def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
+    """
+    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
+    match exactly and data has the same number of columns (factors).
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        The set of factor names used as reference to determine the correct factor names and length of data
+    metadata_2 : Metadata
+        The compared set of factor names and data that must match metadata_1
+
+    Returns
+    -------
+    list[str]
+        The combined discrete and continuous factor names in that order.
+    list[NDArray]
+        Combined discrete and continuous data of metadata_1
+    list[NDArray]
+        Combined discrete and continuous data of metadata_2
+
+    Raises
+    ------
+    ValueError
+        If keys do not match in metadata_1 and metadata_2
+    ValueError
+        If the length of keys do not match the length of the data
+    """
+    factor_names: list[str] = []
+    m1_data: list[NDArray] = []
+    m2_data: list[NDArray] = []
+
+    # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
+    if metadata_1.total_num_factors != metadata_2.total_num_factors:
+        raise ValueError(
+            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
+            f"and metadata_2 ({metadata_2.total_num_factors})"
+        )
+
+    # Validate and attach discrete data
+    if metadata_1.discrete_factor_names:
+        _validate_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
+        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
+
+        factor_names.extend(metadata_1.discrete_factor_names)
+        m1_data.append(metadata_1.discrete_data)
+        m2_data.append(metadata_2.discrete_data)
+
+    # Validate and attach continuous data
+    if metadata_1.continuous_factor_names:
+        _validate_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
+        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
+
+        factor_names.extend(metadata_1.continuous_factor_names)
+        m1_data.append(metadata_1.continuous_data)
+        m2_data.append(metadata_2.continuous_data)
+
+    # Turns list of discrete and continuous into one array
+    return factor_names, m1_data, m2_data
+
+
+def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
+    """
+    Calculates deviations of the test data from the median of the reference data
+
+    Parameters
+    ----------
+    reference : NDArray
+        Reference values of shape (samples, factors)
+    test : NDArray
+        Incoming values where each sample's factors will be compared to the median of
+        the reference set corresponding factors
+
+    Returns
+    -------
+    NDArray
+        Scaled positive and negative deviations of the test data from the reference.
+
+    Note
+    ----
+    All return values are in the range [0, pos_inf]
+    """
+
+    # Take median over samples (rows)
+    ref_median = np.median(reference, axis=0)  # (F, )
+
+    # Shift reference and test distributions by reference
+    ref_dev = reference - ref_median  # (S, F) - F
+    test_dev = test - ref_median  # (S_t, F) - F
+
+    # Separate positive and negative distributions
+    # Fills with nans to keep shape in both 1-D and N-D matrices
+    pdev = np.where(ref_dev > 0, ref_dev, np.nan)  # (S, F)
+    ndev = np.where(ref_dev < 0, ref_dev, np.nan)  # (S, F)
+
+    # Calculate middle of positive and negative distributions per feature
+    pscale = np.nanmedian(pdev, axis=0)  # (F, )
+    nscale = np.abs(np.nanmedian(ndev, axis=0))  # (F, )
+
+    # Replace 0's for division. Negatives should not happen
+    pscale = np.where(pscale > 0, pscale, 1.0)  # (F, )
+    nscale = np.where(nscale > 0, nscale, 1.0)  # (F, )
+
+    # Scales positive values by positive scale and negative values by negative
+    return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))  # (S_t, F)
+
+
+def most_deviated_factors(
+    metadata_1: Metadata,
+    metadata_2: Metadata,
+    ood: OODOutput,
+) -> list[tuple[str, float]]:
+    """
+    Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        A reference set of Metadata containing factor names and samples
+        with discrete and/or continuous values per factor
+    metadata_2 : Metadata
+        The set of Metadata that is tested against the reference metadata.
+        This set must have the same number of features but does not require the same number of samples.
+    ood : OODOutput
+        A class output by the DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    list[tuple[str, float]]
+        An array of the factor name and deviation of the highest metadata deviation for each OOD example in metadata_2.
+
+    Notes
+    -----
+    1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
+       and have equivalent factor names and lengths
+    2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
+       directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    # No metadata correlated with out of distribution data
+    if not any(ood_mask):
+        return []
+
+    # Combines reference and test factor names and data if exists and match exactly
+    # shape -> (samples, factors)
+    factor_names, md_1, md_2 = _combine_metadata(
+        metadata_1=metadata_1,
+        metadata_2=metadata_2,
+    )
+
+    metadata_ref = np.hstack(md_1) if md_1 else np.array([])
+    metadata_tst = np.hstack(md_2) if md_2 else np.array([])
+
+    if len(metadata_ref) < 3:
+        warnings.warn(
+            f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
+            UserWarning,
+        )
+        return []
+
+    if len(metadata_tst) != len(ood_mask):
+        raise ValueError(
+            f"ood and test metadata must have the same length, "
+            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+        )
+
+    # Calculates deviations of all samples in m2_data
+    # from the median values of the corresponding index in m1_data
+    # Guaranteed for inputs to not be empty
+    deviations = _calc_median_deviations(metadata_ref, metadata_tst)
+
+    # Get most impactful factor deviation of each sample for ood samples only
+    deviation = np.max(deviations, axis=1)[ood_mask]
+
+    # Get indices of most impactful factors for ood samples only
+    max_factors = np.argmax(deviations, axis=1)[ood_mask]
+
+    # Get names of most impactful factors TODO: Find better way than np.dtype(<U4)
+    most_ood_factors = np.array(factor_names)[max_factors].tolist()
+
+    # List of tuples matching the factor name with its deviation
+    return [(factor, dev.item()) for factor, dev in zip(most_ood_factors, deviation)]
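Tying the new module together: given a reference `Metadata`, a test `Metadata`, and the `OODOutput` returned by a detector's `predict`, `most_deviated_factors` yields one `(factor_name, deviation)` pair per flagged sample. Construction of the `Metadata` objects lives in `dataeval/utils/data/_metadata.py` and is not shown in this diff, so the sketch assumes `md_ref` and `md_test` already exist:

from dataeval.metadata import most_deviated_factors

# md_ref, md_test: dataeval.utils.data.Metadata instances (construction not shown)
# ood_out: OODOutput from e.g. ood.predict(test_images) above
factors = most_deviated_factors(md_ref, md_test, ood_out)

# One entry per OOD sample in md_test; values here are purely illustrative,
# e.g. [("brightness", 4.2), ("altitude", 1.7)]
for name, deviation in factors:
    print(f"{name}: {deviation:.1f}x the reference's median spread")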
--- a/dataeval/metrics/__init__.py
+++ b/dataeval/metrics/__init__.py
@@ -5,4 +5,4 @@ can then be analyzed in the context of a given problem.
 
 __all__ = ["bias", "estimators", "stats"]
 
-from dataeval.metrics import bias, estimators, stats
+from . import bias, estimators, stats
--- a/dataeval/metrics/bias/__init__.py
+++ b/dataeval/metrics/bias/__init__.py
@@ -7,6 +7,7 @@ __all__ = [
     "BalanceOutput",
     "CoverageOutput",
     "DiversityOutput",
+    "LabelParityOutput",
     "ParityOutput",
     "balance",
     "coverage",
@@ -15,7 +16,7 @@ __all__ = [
     "parity",
 ]
 
-from dataeval.metrics.bias.balance import BalanceOutput, balance
-from dataeval.metrics.bias.coverage import CoverageOutput, coverage
-from dataeval.metrics.bias.diversity import DiversityOutput, diversity
-from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
+from dataeval.metrics.bias._balance import BalanceOutput, balance
+from dataeval.metrics.bias._coverage import CoverageOutput, coverage
+from dataeval.metrics.bias._diversity import DiversityOutput, diversity
+from dataeval.metrics.bias._parity import LabelParityOutput, ParityOutput, label_parity, parity