dataeval 0.74.0__py3-none-any.whl → 0.74.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. dataeval/__init__.py +23 -10
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/linters/clusterer.py +3 -3
  6. dataeval/detectors/linters/duplicates.py +4 -4
  7. dataeval/detectors/linters/outliers.py +4 -4
  8. dataeval/detectors/ood/__init__.py +5 -12
  9. dataeval/detectors/ood/base.py +5 -5
  10. dataeval/detectors/ood/metadata_ks_compare.py +12 -13
  11. dataeval/interop.py +15 -3
  12. dataeval/logging.py +16 -0
  13. dataeval/metrics/bias/balance.py +3 -3
  14. dataeval/metrics/bias/coverage.py +3 -3
  15. dataeval/metrics/bias/diversity.py +3 -3
  16. dataeval/metrics/bias/metadata_preprocessing.py +3 -3
  17. dataeval/metrics/bias/parity.py +4 -4
  18. dataeval/metrics/estimators/ber.py +3 -3
  19. dataeval/metrics/estimators/divergence.py +3 -3
  20. dataeval/metrics/estimators/uap.py +3 -3
  21. dataeval/metrics/stats/base.py +2 -2
  22. dataeval/metrics/stats/boxratiostats.py +1 -1
  23. dataeval/metrics/stats/datasetstats.py +6 -6
  24. dataeval/metrics/stats/dimensionstats.py +1 -1
  25. dataeval/metrics/stats/hashstats.py +1 -1
  26. dataeval/metrics/stats/labelstats.py +3 -3
  27. dataeval/metrics/stats/pixelstats.py +1 -1
  28. dataeval/metrics/stats/visualstats.py +1 -1
  29. dataeval/output.py +81 -57
  30. dataeval/utils/__init__.py +1 -7
  31. dataeval/utils/split_dataset.py +306 -279
  32. dataeval/workflows/sufficiency.py +4 -4
  33. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/METADATA +3 -8
  34. dataeval-0.74.2.dist-info/RECORD +66 -0
  35. dataeval/detectors/ood/ae.py +0 -76
  36. dataeval/detectors/ood/aegmm.py +0 -67
  37. dataeval/detectors/ood/base_tf.py +0 -109
  38. dataeval/detectors/ood/llr.py +0 -302
  39. dataeval/detectors/ood/vae.py +0 -98
  40. dataeval/detectors/ood/vaegmm.py +0 -76
  41. dataeval/utils/lazy.py +0 -26
  42. dataeval/utils/tensorflow/__init__.py +0 -19
  43. dataeval/utils/tensorflow/_internal/gmm.py +0 -103
  44. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  45. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  46. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  47. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  48. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  49. dataeval-0.74.0.dist-info/RECORD +0 -79
  50. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/LICENSE.txt +0 -0
  51. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,23 +1,36 @@
-__version__ = "0.74.0"
+__version__ = "0.74.2"
 
+import logging
 from importlib.util import find_spec
 
+logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+
+def log_stderr(level: int = logging.DEBUG) -> None:
+    """
+    Helper for quickly adding a StreamHandler to the logger. Useful for
+    debugging.
+    """
+    import logging
+
+    logger = logging.getLogger(__name__)
+    handler = logging.StreamHandler()
+    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+    logger.addHandler(handler)
+    logger.setLevel(level)
+    logger.debug("Added a stderr logging handler to logger: %s", __name__)
+
+
 _IS_TORCH_AVAILABLE = find_spec("torch") is not None
 _IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
-_IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
 
 del find_spec
 
 from dataeval import detectors, metrics  # noqa: E402
 
-__all__ = ["detectors", "metrics"]
+__all__ = ["log_stderr", "detectors", "metrics"]
 
 if _IS_TORCH_AVAILABLE:
-    from dataeval import workflows
-
-    __all__ += ["workflows"]
-
-if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:
-    from dataeval import utils
+    from dataeval import utils, workflows
 
-    __all__ += ["utils"]
+    __all__ += ["utils", "workflows"]
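The new log_stderr helper makes the package's otherwise-silent logger (a NullHandler is attached at import) emit to stderr. A minimal usage sketch; the INFO level shown here is just one choice:

    import logging

    import dataeval

    # Route dataeval's internal log records to stderr for debugging.
    dataeval.log_stderr(logging.INFO)

    # Subsequent dataeval calls now report progress, e.g. the array
    # conversions logged at INFO level in dataeval.interop.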
dataeval/detectors/__init__.py CHANGED
@@ -2,14 +2,6 @@
 Detectors can determine if a dataset or individual images in a dataset are indicative of a specific issue.
 """
 
-from dataeval import _IS_TENSORFLOW_AVAILABLE
-from dataeval.detectors import drift, linters
+from dataeval.detectors import drift, linters, ood
 
-__all__ = ["drift", "linters"]
-
-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.detectors import ood
-
-    __all__ += ["ood"]
-
-del _IS_TENSORFLOW_AVAILABLE
+__all__ = ["drift", "linters", "ood"]
dataeval/detectors/drift/base.py CHANGED
@@ -19,7 +19,7 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 R = TypeVar("R")
 
@@ -43,7 +43,7 @@ class UpdateStrategy(ABC):
 
 
 @dataclass(frozen=True)
-class DriftBaseOutput(OutputMetadata):
+class DriftBaseOutput(Output):
     """
     Base output class for Drift detector classes
 
@@ -387,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
-    @set_metadata()
+    @set_metadata
     @preprocess_x
     @update_x_ref
     def predict(
dataeval/detectors/drift/mmd.py CHANGED
@@ -161,7 +161,7 @@ class DriftMMD(BaseDrift):
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
 
-    @set_metadata()
+    @set_metadata
     @preprocess_x
     @update_x_ref
     def predict(self, x: ArrayLike) -> DriftMMDOutput:
dataeval/detectors/linters/clusterer.py CHANGED
@@ -11,12 +11,12 @@ from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform
 
 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import flatten
 
 
 @dataclass(frozen=True)
-class ClustererOutput(OutputMetadata):
+class ClustererOutput(Output):
     """
     Output class for :class:`Clusterer` lint detector
 
@@ -495,7 +495,7 @@ class Clusterer:
         return exact_dupes, near_dupes
 
     # TODO: Move data input to evaluate from class
-    @set_metadata(["data"])
+    @set_metadata(state=["data"])
     def evaluate(self) -> ClustererOutput:
        """Finds and flags indices of the data for Outliers and :term:`duplicates<Duplicates>`
 
dataeval/detectors/linters/duplicates.py CHANGED
@@ -9,7 +9,7 @@ from numpy.typing import ArrayLike
 
 from dataeval.detectors.linters.merged_stats import combine_stats, get_dataset_step_from_idx
 from dataeval.metrics.stats.hashstats import HashStatsOutput, hashstats
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 DuplicateGroup = list[int]
 DatasetDuplicateGroupMap = dict[int, DuplicateGroup]
@@ -17,7 +17,7 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
 
 
 @dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
+class DuplicatesOutput(Generic[TIndexCollection], Output):
     """
     Output class for :class:`Duplicates` lint detector
 
@@ -89,7 +89,7 @@ class Duplicates:
     @overload
     def from_stats(self, hashes: Sequence[HashStatsOutput]) -> DuplicatesOutput[DatasetDuplicateGroupMap]: ...
 
-    @set_metadata(["only_exact"])
+    @set_metadata(state=["only_exact"])
     def from_stats(
         self, hashes: HashStatsOutput | Sequence[HashStatsOutput]
     ) -> DuplicatesOutput[DuplicateGroup] | DuplicatesOutput[DatasetDuplicateGroupMap]:
@@ -138,7 +138,7 @@ class Duplicates:
 
         return DuplicatesOutput(**duplicates)
 
-    @set_metadata(["only_exact"])
+    @set_metadata(state=["only_exact"])
     def evaluate(self, data: Iterable[ArrayLike]) -> DuplicatesOutput[DuplicateGroup]:
        """
        Returns duplicate image indices for both exact matches and near matches
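The call sites above show that set_metadata now works both bare (@set_metadata) and with an explicit state= keyword. The reworked implementation lives in dataeval/output.py (+81 -57, not shown in full in this diff), so the following is only a sketch of the dual-use decorator pattern the call sites imply; the print stands in for whatever bookkeeping the real decorator performs:

    from __future__ import annotations

    from functools import wraps
    from typing import Any, Callable


    def set_metadata(fn: Callable[..., Any] | None = None, state: list[str] | None = None):
        # Bare use (@set_metadata) passes fn directly; parameterized use
        # (@set_metadata(state=[...])) calls with fn=None, so we return
        # the actual decorator.
        if fn is None:
            return lambda f: set_metadata(f, state=state)

        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            output = fn(*args, **kwargs)
            # For methods, capture the named attributes from self (args[0]).
            captured = {name: getattr(args[0], name, None) for name in state or []}
            print(f"{fn.__qualname__} ran with state {captured}")
            return output

        return wrapper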
dataeval/detectors/linters/outliers.py CHANGED
@@ -14,7 +14,7 @@ from dataeval.metrics.stats.datasetstats import DatasetStatsOutput, datasetstats
 from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
 from dataeval.metrics.stats.pixelstats import PixelStatsOutput
 from dataeval.metrics.stats.visualstats import VisualStatsOutput
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 IndexIssueMap = dict[int, dict[str, float]]
 OutlierStatsOutput = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
@@ -22,7 +22,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
+class OutliersOutput(Generic[TIndexIssueMap], Output):
     """
     Output class for :class:`Outliers` lint detector
 
@@ -159,7 +159,7 @@ class Outliers:
     @overload
     def from_stats(self, stats: Sequence[OutlierStatsOutput]) -> OutliersOutput[list[IndexIssueMap]]: ...
 
-    @set_metadata(["outlier_method", "outlier_threshold"])
+    @set_metadata(state=["outlier_method", "outlier_threshold"])
     def from_stats(
         self, stats: OutlierStatsOutput | DatasetStatsOutput | Sequence[OutlierStatsOutput]
     ) -> OutliersOutput[IndexIssueMap] | OutliersOutput[list[IndexIssueMap]]:
@@ -228,7 +228,7 @@ class Outliers:
 
         return OutliersOutput(output_list)
 
-    @set_metadata(["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
+    @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
     def evaluate(self, data: Iterable[ArrayLike]) -> OutliersOutput[IndexIssueMap]:
        """
        Returns indices of Outliers with the issues identified for each
dataeval/detectors/ood/__init__.py CHANGED
@@ -2,21 +2,14 @@
 Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
 """
 
-from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
+from dataeval import _IS_TORCH_AVAILABLE
 from dataeval.detectors.ood.base import OODOutput, OODScoreOutput
 
 __all__ = ["OODOutput", "OODScoreOutput"]
 
-if _IS_TENSORFLOW_AVAILABLE:
-    from dataeval.detectors.ood.ae import OOD_AE
-    from dataeval.detectors.ood.aegmm import OOD_AEGMM
-    from dataeval.detectors.ood.llr import OOD_LLR
-    from dataeval.detectors.ood.vae import OOD_VAE
-    from dataeval.detectors.ood.vaegmm import OOD_VAEGMM
-
-    __all__ += ["OOD_AE", "OOD_AEGMM", "OOD_LLR", "OOD_VAE", "OOD_VAEGMM"]
-
-elif _IS_TORCH_AVAILABLE:
+if _IS_TORCH_AVAILABLE:
     from dataeval.detectors.ood.ae_torch import OOD_AE
 
-    __all__ += ["OOD_AE", "OODOutput"]
+    __all__ += ["OOD_AE"]
+
+del _IS_TORCH_AVAILABLE
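With the TensorFlow backends removed, dataeval.detectors.ood now imports without TensorFlow installed, and the torch-backed OOD_AE is the only conditionally exported detector. A quick check, assuming torch is present:

    from dataeval.detectors import ood

    # Always exports the output classes; "OOD_AE" appears in __all__
    # only when torch is installed.
    print(ood.__all__)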
dataeval/detectors/ood/base.py CHANGED
@@ -18,12 +18,12 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.gmm import GaussianMixtureModelParams
 
 
 @dataclass(frozen=True)
-class OODOutput(OutputMetadata):
+class OODOutput(Output):
     """
     Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
     :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
@@ -44,7 +44,7 @@ class OODOutput(OutputMetadata):
 
 
 @dataclass(frozen=True)
-class OODScoreOutput(OutputMetadata):
+class OODScoreOutput(Output):
     """
     Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
     :class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
@@ -153,7 +153,7 @@ class OODBaseMixin(Generic[TModel], ABC):
     @abstractmethod
     def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput: ...
 
-    @set_metadata()
+    @set_metadata
     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
        """
        Compute the :term:`out of distribution<Out-of-distribution (OOD)>` scores for a given dataset.
@@ -176,7 +176,7 @@ class OODBaseMixin(Generic[TModel], ABC):
     def _threshold_score(self, ood_type: Literal["feature", "instance"] = "instance") -> np.floating:
         return np.percentile(self._ref_score.get(ood_type), self._threshold_perc)
 
-    @set_metadata()
+    @set_metadata
     def predict(
         self,
         X: ArrayLike,
dataeval/detectors/ood/metadata_ks_compare.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import numbers
 import warnings
-from dataclasses import dataclass
 from typing import Any, Mapping, NamedTuple
 
 import numpy as np
@@ -10,7 +9,7 @@ from numpy.typing import NDArray
 from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd
 
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import MappingOutput, set_metadata
 
 
 class MetadataKSResult(NamedTuple):
@@ -20,24 +19,24 @@ class MetadataKSResult(NamedTuple):
     pvalue: float
 
 
-@dataclass(frozen=True)
-class KSOutput(OutputMetadata):
+class KSOutput(MappingOutput[str, MetadataKSResult]):
     """
-    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+    Output dictionary class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
 
     Attributes
     ----------
-    mdc : dict[str, dict[str, float]]
-        dict keyed by metadata feature names. Each value contains four floats, which are the KS statistic itself, its
-        location within the range of the reference metadata, the shift of new metadata relative to reference, the
-        p-value from the KS two-sample test.
-
+    key: str
+        Metadata feature names
+    value: NamedTuple[float, float, float, float]
+        Each value contains four floats, which are:
+        - statistic: the KS statistic itself
+        - statistic_location: its location within the range of the reference metadata
+        - shift_magnitude: the shift of new metadata relative to reference
+        - pvalue: the p-value from the KS two-sample test
     """
 
-    mdc: dict[str, MetadataKSResult]
-
 
-@set_metadata()
+@set_metadata
 def meta_distribution_compare(
     md0: Mapping[str, list[Any] | NDArray[Any]], md1: Mapping[str, list[Any] | NDArray[Any]]
 ) -> KSOutput:
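Since KSOutput is now a MappingOutput keyed by feature name rather than a dataclass wrapping an mdc attribute, results are read directly off the output. A sketch assuming MappingOutput exposes the standard Mapping interface; feature names and values below are illustrative:

    from dataeval.detectors.ood.metadata_ks_compare import meta_distribution_compare

    # Reference metadata and new metadata, keyed by feature name.
    md0 = {"altitude": [1000, 1100, 980], "time": [1.1, 2.3, 0.4]}
    md1 = {"altitude": [1020, 1550, 990], "time": [1.0, 2.9, 0.8]}

    result = meta_distribution_compare(md0, md1)
    for feature, ks in result.items():
        print(feature, ks.statistic, ks.shift_magnitude, ks.pvalue)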
dataeval/interop.py CHANGED
@@ -1,23 +1,31 @@
 from __future__ import annotations
 
+from types import ModuleType
+
+from dataeval.logging import LogMessage
+
 __all__ = ["as_numpy", "to_numpy", "to_numpy_iter"]
 
+import logging
 from importlib import import_module
 from typing import Any, Iterable, Iterator
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
+_logger = logging.getLogger(__name__)
+
 _MODULE_CACHE = {}
 
 
-def _try_import(module_name):
+def _try_import(module_name) -> ModuleType | None:
     if module_name in _MODULE_CACHE:
         return _MODULE_CACHE[module_name]
 
     try:
         module = import_module(module_name)
     except ImportError:  # pragma: no cover - covered by test_mindeps.py
+        _logger.log(logging.INFO, f"Unable to import {module_name}.")
         module = None
 
     _MODULE_CACHE[module_name] = module
@@ -40,14 +48,18 @@ def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
     if array.__class__.__module__.startswith("tensorflow"):
         tf = _try_import("tensorflow")
         if tf and tf.is_tensor(array):
+            _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
             return array.numpy().copy() if copy else array.numpy()  # type: ignore
 
     if array.__class__.__module__.startswith("torch"):
         torch = _try_import("torch")
         if torch and isinstance(array, torch.Tensor):
-            return array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
+            _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
+            numpy = array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
+            _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(numpy)}"))
+            return numpy
 
-    return np.array(array, copy=copy)
+    return np.array(array) if copy else np.asarray(array)
 
 
 def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
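The last change to to_numpy tracks the NumPy 2.0 copy semantics: np.array(x, copy=False) now raises when a copy cannot be avoided, where it previously fell back to copying, so the "copy only if needed" path is expressed with np.asarray instead. An illustration:

    import numpy as np

    x = [1, 2, 3]          # converting a list always requires a copy

    np.array(x)            # always returns a new array
    np.asarray(x)          # copies only when required; same on NumPy 1.x and 2.x
    # np.array(x, copy=False)  # raises ValueError on NumPy 2.x: copy is required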
dataeval/logging.py ADDED
@@ -0,0 +1,16 @@
+from typing import Callable
+
+
+class LogMessage:
+    """
+    Deferred message callback for logging expensive messages.
+    """
+
+    def __init__(self, fn: Callable[..., str]):
+        self._fn = fn
+        self._str = None
+
+    def __str__(self) -> str:
+        if self._str is None:
+            self._str = self._fn()
+        return self._str
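LogMessage defers formatting until the logging framework calls str() on the message, which only happens for records that are actually emitted, so an expensive string (like the tensor dump in interop.py above) costs nothing when its level is disabled. A usage sketch; expensive_summary is a stand-in for costly work:

    import logging

    from dataeval.logging import LogMessage

    logger = logging.getLogger("dataeval")

    def expensive_summary() -> str:
        return "summary of a large array"  # stands in for costly work

    # The callback runs only if a handler actually emits the DEBUG record.
    logger.debug(LogMessage(expensive_summary))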
dataeval/metrics/bias/balance.py CHANGED
@@ -14,14 +14,14 @@ from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
 from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
 from dataeval.metrics.bias.metadata_utils import get_counts, heatmap
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
 
 
 @dataclass(frozen=True)
-class BalanceOutput(OutputMetadata):
+class BalanceOutput(Output):
     """
     Output class for :func:`balance` bias metric
 
@@ -117,7 +117,7 @@ def _validate_num_neighbors(num_neighbors: int) -> int:
     return num_neighbors
 
 
-@set_metadata("dataeval.metrics")
+@set_metadata
 def balance(
     metadata: MetadataOutput,
     num_neighbors: int = 5,
dataeval/metrics/bias/coverage.py CHANGED
@@ -13,7 +13,7 @@
 
 from dataeval.interop import to_numpy
 from dataeval.metrics.bias.metadata_utils import coverage_plot
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import flatten
 
 with contextlib.suppress(ImportError):
@@ -21,7 +21,7 @@ with contextlib.suppress(ImportError):
 
 
 @dataclass(frozen=True)
-class CoverageOutput(OutputMetadata):
+class CoverageOutput(Output):
     """
     Output class for :func:`coverage` :term:`bias<Bias>` metric
 
@@ -67,7 +67,7 @@ class CoverageOutput(OutputMetadata):
         return fig
 
 
-@set_metadata()
+@set_metadata
 def coverage(
     embeddings: ArrayLike,
     radius_type: Literal["adaptive", "naive"] = "adaptive",
dataeval/metrics/bias/diversity.py CHANGED
@@ -12,7 +12,7 @@ from numpy.typing import ArrayLike, NDArray
 
 from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
 from dataeval.metrics.bias.metadata_utils import diversity_bar_plot, get_counts, heatmap
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import get_method
 
 with contextlib.suppress(ImportError):
@@ -20,7 +20,7 @@ with contextlib.suppress(ImportError):
 
 
 @dataclass(frozen=True)
-class DiversityOutput(OutputMetadata):
+class DiversityOutput(Output):
     """
     Output class for :func:`diversity` :term:`bias<Bias>` metric
 
@@ -163,7 +163,7 @@ def diversity_simpson(
     return ev_index
 
 
-@set_metadata()
+@set_metadata
 def diversity(
     metadata: MetadataOutput,
     method: Literal["simpson", "shannon"] = "simpson",
dataeval/metrics/bias/metadata_preprocessing.py CHANGED
@@ -11,7 +11,7 @@ from numpy.typing import ArrayLike, NDArray
 from scipy.stats import wasserstein_distance as wd
 
 from dataeval.interop import as_numpy, to_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.metadata import merge_metadata
 
 TNum = TypeVar("TNum", int, float)
@@ -20,7 +20,7 @@ CONTINUOUS_MIN_SAMPLE_SIZE = 20
 
 
 @dataclass(frozen=True)
-class MetadataOutput(OutputMetadata):
+class MetadataOutput(Output):
     """
     Output class for :func:`metadata_binning` function
 
@@ -51,7 +51,7 @@ class MetadataOutput(OutputMetadata):
     total_num_factors: int
 
 
-@set_metadata()
+@set_metadata
 def metadata_preprocessing(
     raw_metadata: Iterable[Mapping[str, Any]],
     class_labels: ArrayLike | str,
dataeval/metrics/bias/parity.py CHANGED
@@ -13,13 +13,13 @@ from scipy.stats.contingency import chi2_contingency, crosstab
 
 from dataeval.interop import as_numpy, to_numpy
 from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
 
 @dataclass(frozen=True)
-class ParityOutput(Generic[TData], OutputMetadata):
+class ParityOutput(Generic[TData], Output):
     """
     Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics
 
@@ -116,7 +116,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
         )
 
 
-@set_metadata()
+@set_metadata
 def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
@@ -204,7 +204,7 @@ def label_parity(
     return ParityOutput(cs, p, None)
 
 
-@set_metadata()
+@set_metadata
 def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
     """
     Calculate chi-square statistics to assess the linear relationship between multiple factors
dataeval/metrics/estimators/ber.py CHANGED
@@ -20,12 +20,12 @@ from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
 
 
 @dataclass(frozen=True)
-class BEROutput(OutputMetadata):
+class BEROutput(Output):
     """
     Output class for :func:`ber` estimator metric
 
@@ -114,7 +114,7 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
     return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
 
 
-@set_metadata()
+@set_metadata
 def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
     """
     An estimator for Multi-class :term:`Bayes error rate<Bayes Error Rate (BER)>` using FR or KNN test statistic basis
dataeval/metrics/estimators/divergence.py CHANGED
@@ -14,12 +14,12 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import compute_neighbors, get_method, minimum_spanning_tree
 
 
 @dataclass(frozen=True)
-class DivergenceOutput(OutputMetadata):
+class DivergenceOutput(Output):
     """
     Output class for :func:`divergence` estimator metric
 
@@ -78,7 +78,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     return errors
 
 
-@set_metadata()
+@set_metadata
 def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
     """
     Calculates the :term:`divergence` and any errors between the datasets
dataeval/metrics/estimators/uap.py CHANGED
@@ -14,11 +14,11 @@ from numpy.typing import ArrayLike
 from sklearn.metrics import average_precision_score
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 
 @dataclass(frozen=True)
-class UAPOutput(OutputMetadata):
+class UAPOutput(Output):
     """
     Output class for :func:`uap` estimator metric
 
@@ -31,7 +31,7 @@ class UAPOutput(OutputMetadata):
     uap: float
 
 
-@set_metadata()
+@set_metadata
 def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
     """
     FR Test Statistic based estimate of the empirical mean precision for
dataeval/metrics/stats/base.py CHANGED
@@ -15,7 +15,7 @@ import tqdm
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import to_numpy_iter
-from dataeval.output import OutputMetadata
+from dataeval.output import Output
 from dataeval.utils.image import normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
@@ -65,7 +65,7 @@ class SourceIndex(NamedTuple):
 
 
 @dataclass(frozen=True)
-class BaseStatsOutput(OutputMetadata):
+class BaseStatsOutput(Output):
     """
     Attributes
     ----------
dataeval/metrics/stats/boxratiostats.py CHANGED
@@ -96,7 +96,7 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsOutput)
     return out_stats
 
 
-@set_metadata()
+@set_metadata
 def boxratiostats(
     boxstats: TStatOutput,
     imgstats: TStatOutput,