dataeval 0.82.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/_mmd.py +9 -9
  4. dataeval/detectors/drift/_torch.py +7 -7
  5. dataeval/detectors/drift/_uncertainty.py +4 -4
  6. dataeval/detectors/linters/duplicates.py +3 -3
  7. dataeval/detectors/linters/outliers.py +3 -3
  8. dataeval/detectors/ood/ae.py +5 -4
  9. dataeval/detectors/ood/base.py +2 -2
  10. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  11. dataeval/detectors/ood/mixin.py +1 -1
  12. dataeval/detectors/ood/vae.py +2 -1
  13. dataeval/metadata/_distance.py +11 -44
  14. dataeval/metadata/_ood.py +9 -7
  15. dataeval/metrics/bias/_balance.py +7 -3
  16. dataeval/metrics/bias/_diversity.py +3 -0
  17. dataeval/metrics/bias/_parity.py +2 -0
  18. dataeval/metrics/stats/_base.py +3 -3
  19. dataeval/metrics/stats/_boxratiostats.py +1 -1
  20. dataeval/metrics/stats/_imagestats.py +4 -4
  21. dataeval/outputs/__init__.py +4 -0
  22. dataeval/outputs/_base.py +50 -21
  23. dataeval/outputs/_bias.py +1 -1
  24. dataeval/outputs/_linters.py +4 -2
  25. dataeval/outputs/_metadata.py +54 -0
  26. dataeval/outputs/_stats.py +12 -6
  27. dataeval/utils/data/_embeddings.py +8 -9
  28. dataeval/utils/data/_metadata.py +16 -7
  29. dataeval/utils/data/_selection.py +4 -8
  30. dataeval/utils/data/_split.py +3 -2
  31. dataeval/utils/data/selections/_classfilter.py +5 -3
  32. dataeval/utils/torch/_internal.py +5 -5
  33. dataeval/utils/torch/trainer.py +8 -8
  34. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +1 -1
  35. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/RECORD +37 -36
  36. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  37. {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.82.0"
11
+ __version__ = "0.82.1"
12
12
 
13
13
  import logging
14
14
 
dataeval/config.py CHANGED
@@ -4,36 +4,61 @@ Global configuration settings for DataEval.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
8
8
 
9
+ import sys
10
+ from typing import Union
11
+
12
+ if sys.version_info >= (3, 10):
13
+ from typing import TypeAlias
14
+ else:
15
+ from typing_extensions import TypeAlias
16
+
17
+ import numpy as np
9
18
  import torch
10
- from torch import device
11
19
 
12
- _device: device | None = None
20
+ _device: torch.device | None = None
13
21
  _processes: int | None = None
22
+ _seed: int | None = None
23
+
24
+ DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
25
+ """
26
+ Type alias for types that are acceptable for specifying a torch.device.
27
+
28
+ See Also
29
+ --------
30
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
31
+ """
32
+
14
33
 
34
def _todevice(device: DeviceLike) -> torch.device:
    """
    Converts a DeviceLike value into a ``torch.device``.

    A tuple is unpacked into the positional arguments of ``torch.device``;
    any other value is handed to ``torch.device`` as its single argument.
    """
    if isinstance(device, tuple):
        return torch.device(*device)
    return torch.device(device)
15
36
 
16
- def set_device(device: str | device | int) -> None:
37
+
38
+ def set_device(device: DeviceLike) -> None:
17
39
  """
18
40
  Sets the default device to use when executing against a PyTorch backend.
19
41
 
20
42
  Parameters
21
43
  ----------
22
- device : str or int or `torch.device`
23
- The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
24
- documentation for more information.
44
+ device : DeviceLike
45
+ The default device to use. See documentation for more information.
46
+
47
+ See Also
48
+ --------
49
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
25
50
  """
26
51
  global _device
27
- _device = torch.device(device)
52
+ _device = _todevice(device)
28
53
 
29
54
 
30
- def get_device(override: str | device | int | None = None) -> torch.device:
55
+ def get_device(override: DeviceLike | None = None) -> torch.device:
31
56
  """
32
57
  Returns the PyTorch device to use.
33
58
 
34
59
  Parameters
35
60
  ----------
36
- override : str or int or `torch.device` or None, default None
61
+ override : DeviceLike or None, default None
37
62
  The user specified override if provided, otherwise returns the default device.
38
63
 
39
64
  Returns
@@ -44,7 +69,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
44
69
  global _device
45
70
  return torch.get_default_device() if _device is None else _device
46
71
  else:
47
- return torch.device(override)
72
+ return _todevice(override)
48
73
 
49
74
 
50
75
  def set_max_processes(processes: int | None) -> None:
@@ -75,3 +100,35 @@ def get_max_processes() -> int | None:
75
100
  """
76
101
  global _processes
77
102
  return _processes
103
+
104
+
105
+ def set_seed(seed: int | None, all_generators: bool = False) -> None:
106
+ """
107
+ Sets the seed for use by classes that allow for a random state or seed.
108
+
109
+ Parameters
110
+ ----------
111
+ seed : int or None
112
+ The seed to use.
113
+ all_generators : bool, default False
114
+ Whether to set the seed for all generators, including NumPy and PyTorch.
115
+ """
116
+ global _seed
117
+ _seed = seed
118
+
119
+ if all_generators:
120
+ np.random.seed(seed)
121
+ torch.manual_seed(seed)
122
+
123
+
124
+ def get_seed() -> int | None:
125
+ """
126
+ Returns the seed for random state or seed.
127
+
128
+ Returns
129
+ -------
130
+ int or None
131
+ The seed to use.
132
+ """
133
+ global _seed
134
+ return _seed
@@ -14,7 +14,7 @@ from typing import Callable
14
14
 
15
15
  import torch
16
16
 
17
- from dataeval.config import get_device
17
+ from dataeval.config import DeviceLike, get_device
18
18
  from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
19
19
  from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
20
20
  from dataeval.outputs import DriftMMDOutput
@@ -31,7 +31,7 @@ class DriftMMD(BaseDrift):
31
31
  ----------
32
32
  x_ref : ArrayLike
33
33
  Data used as reference distribution.
34
- p_val : float | None, default 0.05
34
+ p_val : float or None, default 0.05
35
35
  :term:`P-value` used for significance of the statistical test for each feature.
36
36
  If the FDR correction method is used, this corresponds to the acceptable
37
37
  q-value.
@@ -39,14 +39,14 @@ class DriftMMD(BaseDrift):
39
39
  Whether the given reference data ``x_ref`` has been preprocessed yet.
40
40
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
41
41
  If ``False``, the reference data will also be preprocessed.
42
- update_x_ref : UpdateStrategy | None, default None
42
+ update_x_ref : UpdateStrategy or None, default None
43
43
  Reference data can optionally be updated using an UpdateStrategy class. Update
44
44
  using the last n instances seen by the detector with LastSeenUpdateStrategy
45
45
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
46
- preprocess_fn : Callable | None, default None
46
+ preprocess_fn : Callable or None, default None
47
47
  Function to preprocess the data before computing the data drift metrics.
48
48
  Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
49
- sigma : ArrayLike | None, default None
49
+ sigma : ArrayLike or None, default None
50
50
  Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
51
51
  bandwidth values as an array. The kernel evaluation is then averaged over
52
52
  those bandwidths.
@@ -54,9 +54,9 @@ class DriftMMD(BaseDrift):
54
54
  Whether to already configure the kernel bandwidth from the reference data.
55
55
  n_permutations : int, default 100
56
56
  Number of permutations used in the permutation test.
57
- device : str | None, default None
58
- Device type used. The default None uses the GPU and falls back on CPU.
59
- Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
57
+ device : DeviceLike or None, default None
58
+ The hardware device to use if specified, otherwise uses the DataEval
59
+ default or torch default.
60
60
 
61
61
  Example
62
62
  -------
@@ -84,7 +84,7 @@ class DriftMMD(BaseDrift):
84
84
  sigma: ArrayLike | None = None,
85
85
  configure_kernel_from_x_ref: bool = True,
86
86
  n_permutations: int = 100,
87
- device: str | torch.device | None = None,
87
+ device: DeviceLike | None = None,
88
88
  ) -> None:
89
89
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
90
90
 
@@ -17,7 +17,7 @@ import torch
17
17
  import torch.nn as nn
18
18
  from numpy.typing import NDArray
19
19
 
20
- from dataeval.config import get_device
20
+ from dataeval.config import DeviceLike, get_device
21
21
  from dataeval.utils.torch._internal import predict_batch
22
22
 
23
23
 
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
59
59
  def preprocess_drift(
60
60
  x: NDArray[Any],
61
61
  model: nn.Module,
62
- device: str | torch.device | None = None,
62
+ device: DeviceLike | None = None,
63
63
  preprocess_batch_fn: Callable | None = None,
64
64
  batch_size: int = int(1e10),
65
65
  dtype: type[np.generic] | torch.dtype = np.float32,
@@ -73,15 +73,15 @@ def preprocess_drift(
73
73
  Batch of instances.
74
74
  model : nn.Module
75
75
  Model used for preprocessing.
76
- device : torch.device | None, default None
77
- Device type used. The default None tries to use the GPU and falls back on CPU.
78
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
79
- preprocess_batch_fn : Callable | None, default None
76
+ device : DeviceLike or None, default None
77
+ The hardware device to use if specified, otherwise uses the DataEval
78
+ default or torch default.
79
+ preprocess_batch_fn : Callable or None, default None
80
80
  Optional batch preprocessing function. For example to convert a list of objects
81
81
  to a batch which can be processed by the PyTorch model.
82
82
  batch_size : int, default 1e10
83
83
  Batch size used during prediction.
84
- dtype : np.dtype | torch.dtype, default np.float32
84
+ dtype : np.dtype or torch.dtype, default np.float32
85
85
  Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
86
86
 
87
87
  Returns
@@ -85,20 +85,20 @@ class DriftUncertainty:
85
85
  Whether the given reference data ``x_ref`` has been preprocessed yet.
86
86
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
87
87
  If ``False``, the reference data will also be preprocessed.
88
- update_x_ref : UpdateStrategy | None, default None
88
+ update_x_ref : UpdateStrategy or None, default None
89
89
  Reference data can optionally be updated using an UpdateStrategy class. Update
90
90
  using the last n instances seen by the detector with LastSeenUpdateStrategy
91
91
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
92
- preds_type : "probs" | "logits", default "probs"
92
+ preds_type : "probs" or "logits", default "probs"
93
93
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
94
94
  'logits' (in [-inf,inf]).
95
95
  batch_size : int, default 32
96
96
  Batch size used to evaluate model. Only relevant when backend has been
97
97
  specified for batch prediction.
98
- preprocess_batch_fn : Callable | None, default None
98
+ preprocess_batch_fn : Callable or None, default None
99
99
  Optional batch preprocessing function. For example to convert a list of
100
100
  objects to a batch which can be processed by the model.
101
- device : str | None, default None
101
+ device : DeviceLike or None, default None
102
102
  Device type used. The default None tries to use the GPU and falls back on
103
103
  CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
104
104
 
@@ -88,13 +88,13 @@ class Duplicates:
88
88
  """
89
89
 
90
90
  if isinstance(hashes, HashStatsOutput):
91
- return DuplicatesOutput(**self._get_duplicates(hashes.dict()))
91
+ return DuplicatesOutput(**self._get_duplicates(hashes.data()))
92
92
 
93
93
  if not isinstance(hashes, Sequence):
94
94
  raise TypeError("Invalid stats output type; only use output from hashstats.")
95
95
 
96
96
  combined, dataset_steps = combine_stats(hashes)
97
- duplicates = self._get_duplicates(combined.dict())
97
+ duplicates = self._get_duplicates(combined.data())
98
98
 
99
99
  # split up results from combined dataset into individual dataset buckets
100
100
  for dup_type, dup_list in duplicates.items():
@@ -136,5 +136,5 @@ class Duplicates:
136
136
  """ # noqa: E501
137
137
  images = Images(data) if isinstance(data, Dataset) else data
138
138
  self.stats = hashstats(images)
139
- duplicates = self._get_duplicates(self.stats.dict())
139
+ duplicates = self._get_duplicates(self.stats.data())
140
140
  return DuplicatesOutput(**duplicates)
@@ -169,7 +169,7 @@ class Outliers:
169
169
  {}
170
170
  """ # noqa: E501
171
171
  if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
172
- return OutliersOutput(self._get_outliers(stats.dict()))
172
+ return OutliersOutput(self._get_outliers(stats.data()))
173
173
 
174
174
  if not isinstance(stats, Sequence):
175
175
  raise TypeError(
@@ -189,7 +189,7 @@ class Outliers:
189
189
  output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
190
190
  for _, indices in stats_map.items():
191
191
  substats, dataset_steps = combine_stats([stats[i] for i in indices])
192
- outliers = self._get_outliers(substats.dict())
192
+ outliers = self._get_outliers(substats.data())
193
193
  for idx, issue in outliers.items():
194
194
  k, v = get_dataset_step_from_idx(idx, dataset_steps)
195
195
  output_list[indices[k]][v] = issue
@@ -225,5 +225,5 @@ class Outliers:
225
225
  """
226
226
  images = Images(data) if isinstance(data, Dataset) else data
227
227
  self.stats = imagestats(images)
228
- outliers = self._get_outliers(self.stats.dict())
228
+ outliers = self._get_outliers(self.stats.data())
229
229
  return OutliersOutput(outliers)
@@ -18,6 +18,7 @@ import numpy as np
18
18
  import torch
19
19
  from numpy.typing import NDArray
20
20
 
21
+ from dataeval.config import DeviceLike
21
22
  from dataeval.detectors.ood.base import OODBase
22
23
  from dataeval.outputs import OODScoreOutput
23
24
  from dataeval.typing import ArrayLike
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
33
34
  model : torch.nn.Module
34
35
  An autoencoder model to use for encoding and reconstruction of images
35
36
  for detection of out-of-distribution samples.
36
- device : str or torch.Device or None, default None
37
- The device to use for the detector. None will default to the global
38
- configuration selection if set, otherwise "cuda" then "cpu" by availability.
37
+ device : DeviceLike or None, default None
38
+ The hardware device to use if specified, otherwise uses the DataEval
39
+ default or torch default.
39
40
 
40
41
  Example
41
42
  -------
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
57
58
  array([ True, True, False, True, True, True, True, True])
58
59
  """
59
60
 
60
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
61
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
61
62
  super().__init__(model, device)
62
63
 
63
64
  def fit(
@@ -14,7 +14,7 @@ from typing import Callable, cast
14
14
 
15
15
  import torch
16
16
 
17
- from dataeval.config import get_device
17
+ from dataeval.config import DeviceLike, get_device
18
18
  from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
19
19
  from dataeval.typing import ArrayLike
20
20
  from dataeval.utils._array import to_numpy
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
23
23
 
24
24
 
25
25
  class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
26
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
26
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
27
27
  self.device: torch.device = get_device(device)
28
28
  super().__init__(model)
29
29
 
@@ -10,6 +10,8 @@ import numpy as np
10
10
  from numpy.typing import NDArray
11
11
  from sklearn.feature_selection import mutual_info_classif
12
12
 
13
+ from dataeval.config import get_seed
14
+
13
15
  # NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
14
16
  # which is what many library functions return, multiply it by NATS2BITS to get it in bits.
15
17
  NATS2BITS = 1.442695
@@ -19,7 +21,6 @@ def get_metadata_ood_mi(
19
21
  metadata: dict[str, list[Any] | NDArray[Any]],
20
22
  is_ood: NDArray[np.bool_],
21
23
  discrete_features: str | bool | NDArray[np.bool_] = False,
22
- random_state: int | None = None,
23
24
  ) -> dict[str, float]:
24
25
  """Computes mutual information between a set of metadata features and an out-of-distribution flag.
25
26
 
@@ -39,9 +40,6 @@ def get_metadata_ood_mi(
39
40
  A boolean array, with one value per example, that indicates which examples are OOD.
40
41
  discrete_features : str | bool | NDArray[np.bool_]
41
42
  Either a boolean array or a single boolean value, indicate which features take on discrete values.
42
- random_state : int, optional - default None
43
- Determines random number generation for small noise added to continuous variables. Set to a value for
44
- reproducible results.
45
43
 
46
44
  Returns
47
45
  -------
@@ -55,7 +53,7 @@ def get_metadata_ood_mi(
55
53
 
56
54
  >>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
57
55
  >>> is_ood = metadata["altitude"] > 100
58
- >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False, random_state=0)
56
+ >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False)
59
57
  {'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
60
58
  """
61
59
  numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
@@ -84,7 +82,7 @@ def get_metadata_ood_mi(
84
82
  Xscl,
85
83
  is_ood,
86
84
  discrete_features=discrete_features, # type: ignore
87
- random_state=random_state,
85
+ random_state=get_seed(),
88
86
  )
89
87
  * NATS2BITS
90
88
  )
@@ -157,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
157
157
  # compute outlier scores
158
158
  score = self.score(X, batch_size=batch_size)
159
159
  ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
160
- return OODOutput(is_ood=ood_pred, **score.dict())
160
+ return OODOutput(is_ood=ood_pred, **score.data())
@@ -17,6 +17,7 @@ from typing import Callable
17
17
  import numpy as np
18
18
  import torch
19
19
 
20
+ from dataeval.config import DeviceLike
20
21
  from dataeval.detectors.ood.base import OODBase
21
22
  from dataeval.outputs import OODScoreOutput
22
23
  from dataeval.typing import ArrayLike
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
34
35
  An Autoencoder model.
35
36
  """
36
37
 
37
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
38
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
38
39
  super().__init__(model, device)
39
40
 
40
41
  def fit(
@@ -10,7 +10,8 @@ from scipy.stats import iqr, ks_2samp
10
10
  from scipy.stats import wasserstein_distance as emd
11
11
 
12
12
  from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
13
- from dataeval.outputs._base import MappingOutput
13
+ from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
14
+ from dataeval.outputs._base import set_metadata
14
15
  from dataeval.typing import ArrayLike
15
16
  from dataeval.utils.data import Metadata
16
17
 
@@ -23,41 +24,6 @@ class KSType(NamedTuple):
23
24
  pvalue: float
24
25
 
25
26
 
26
- class MetadataKSResult(NamedTuple):
27
- """
28
- Attributes
29
- ----------
30
- statistic : float
31
- the KS statistic
32
- location : float
33
- The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
34
- to the median of the reference distribution.
35
- dist : float
36
- The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
37
- pvalue : float
38
- The p-value from the KS two-sample test
39
- """
40
-
41
- statistic: float
42
- location: float
43
- dist: float
44
- pvalue: float
45
-
46
-
47
- class KSOutput(MappingOutput[str, MetadataKSResult]):
48
- """
49
- Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
50
-
51
-
52
- Attributes
53
- ----------
54
- key: str
55
- Metadata feature names
56
- value: :class:`MetadataKSResult`
57
- Output per feature name containing the statistic, statistic location, distance, and pvalue.
58
- """
59
-
60
-
61
27
  def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
62
28
  """Calculates the shift magnitude between x1 and x2 scaled by x1"""
63
29
 
@@ -74,7 +40,8 @@ def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
74
40
  return distance if xmin == xmax else distance / (xmax - xmin)
75
41
 
76
42
 
77
- def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
43
+ @set_metadata
44
+ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
78
45
  """
79
46
  Measures the feature-wise distance between two continuous metadata distributions and
80
47
  computes a p-value to evaluate its significance.
@@ -90,8 +57,8 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
90
57
 
91
58
  Returns
92
59
  -------
93
- dict[str, KstestResult]
94
- A dictionary with keys corresponding to metadata feature names, and values that are KstestResult objects, as
60
+ MetadataDistanceOutput
61
+ A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
95
62
  defined by scipy.stats.ks_2samp.
96
63
 
97
64
  See Also
@@ -110,7 +77,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
110
77
  >>> list(output)
111
78
  ['time', 'altitude']
112
79
  >>> output["time"]
113
- MetadataKSResult(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
80
+ MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
114
81
  """
115
82
 
116
83
  _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
@@ -134,7 +101,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
134
101
  )
135
102
 
136
103
  # Set default for statistic, location, and magnitude to zero and pvalue to one
137
- results: dict[str, MetadataKSResult] = {}
104
+ results: dict[str, MetadataDistanceValues] = {}
138
105
 
139
106
  # Per factor
140
107
  for i, fname in enumerate(fnames):
@@ -147,7 +114,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
147
114
 
148
115
  # Default case
149
116
  if xmin == xmax:
150
- results[fname] = MetadataKSResult(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
117
+ results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
151
118
  continue
152
119
 
153
120
  ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
@@ -157,11 +124,11 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
157
124
 
158
125
  drift = _calculate_drift(fdata1, fdata2)
159
126
 
160
- results[fname] = MetadataKSResult(
127
+ results[fname] = MetadataDistanceValues(
161
128
  statistic=ks_result.statistic,
162
129
  location=loc,
163
130
  dist=drift,
164
131
  pvalue=ks_result.pvalue,
165
132
  )
166
133
 
167
- return KSOutput(results)
134
+ return MetadataDistanceOutput(results)
dataeval/metadata/_ood.py CHANGED
@@ -8,7 +8,8 @@ import numpy as np
8
8
  from numpy.typing import NDArray
9
9
 
10
10
  from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
11
- from dataeval.outputs import OODOutput
11
+ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
12
+ from dataeval.outputs._base import set_metadata
12
13
  from dataeval.utils.data import Metadata
13
14
 
14
15
 
@@ -119,11 +120,12 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
119
120
  return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale)) # (S_t, F)
120
121
 
121
122
 
123
+ @set_metadata
122
124
  def most_deviated_factors(
123
125
  metadata_1: Metadata,
124
126
  metadata_2: Metadata,
125
127
  ood: OODOutput,
126
- ) -> list[tuple[str, float]]:
128
+ ) -> MostDeviatedFactorsOutput:
127
129
  """
128
130
  Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
129
131
 
@@ -159,20 +161,20 @@ def most_deviated_factors(
159
161
 
160
162
  >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
161
163
  >>> most_deviated_factors(metadata1, metadata2, is_ood)
162
- [('time', 2.0), ('time', 2.592), ('time', 3.51)]
164
+ MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
163
165
 
164
166
  If there are no out-of-distribution samples, a list is returned
165
167
 
166
168
  >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
167
169
  >>> most_deviated_factors(metadata1, metadata2, is_ood)
168
- []
170
+ MostDeviatedFactorsOutput([])
169
171
  """
170
172
 
171
173
  ood_mask: NDArray[np.bool] = ood.is_ood
172
174
 
173
175
  # No metadata correlated with out of distribution data
174
176
  if not any(ood_mask):
175
- return []
177
+ return MostDeviatedFactorsOutput([])
176
178
 
177
179
  # Combines reference and test factor names and data if exists and match exactly
178
180
  # shape -> (samples, factors)
@@ -190,7 +192,7 @@ def most_deviated_factors(
190
192
  f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
191
193
  UserWarning,
192
194
  )
193
- return []
195
+ return MostDeviatedFactorsOutput([])
194
196
 
195
197
  if len(metadata_tst) != len(ood_mask):
196
198
  raise ValueError(
@@ -214,4 +216,4 @@ def most_deviated_factors(
214
216
 
215
217
  # List of tuples matching the factor name with its deviation
216
218
 
217
- return [(factor, dev) for factor, dev in zip(most_ood_factors, deviation)]
219
+ return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
@@ -8,6 +8,7 @@ import numpy as np
8
8
  import scipy as sp
9
9
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
10
10
 
11
+ from dataeval.config import get_seed
11
12
  from dataeval.outputs import BalanceOutput
12
13
  from dataeval.outputs._base import set_metadata
13
14
  from dataeval.utils._bin import get_counts
@@ -91,6 +92,9 @@ def balance(
91
92
  sklearn.feature_selection.mutual_info_regression
92
93
  sklearn.metrics.mutual_info_score
93
94
  """
95
+ if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
96
+ raise ValueError("No factors found in provided metadata.")
97
+
94
98
  num_neighbors = _validate_num_neighbors(num_neighbors)
95
99
 
96
100
  num_factors = metadata.total_num_factors
@@ -110,7 +114,7 @@ def balance(
110
114
  data[:, idx],
111
115
  discrete_features=is_discrete, # type: ignore
112
116
  n_neighbors=num_neighbors,
113
- random_state=0,
117
+ random_state=get_seed(),
114
118
  )
115
119
  else:
116
120
  mi[idx, :] = mutual_info_classif(
@@ -118,7 +122,7 @@ def balance(
118
122
  data[:, idx],
119
123
  discrete_features=is_discrete, # type: ignore
120
124
  n_neighbors=num_neighbors,
121
- random_state=0,
125
+ random_state=get_seed(),
122
126
  )
123
127
 
124
128
  # Normalization via entropy
@@ -147,7 +151,7 @@ def balance(
147
151
  tgt_bin[:, idx],
148
152
  discrete_features=is_discrete, # type: ignore
149
153
  n_neighbors=num_neighbors,
150
- random_state=0,
154
+ random_state=get_seed(),
151
155
  )
152
156
 
153
157
  # Classwise normalization via entropy
@@ -158,6 +158,9 @@ def diversity(
158
158
  --------
159
159
  scipy.stats.entropy
160
160
  """
161
+ if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
162
+ raise ValueError("No factors found in provided metadata.")
163
+
161
164
  diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
162
165
  discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
163
166
  cnts = get_counts(discretized_data)
@@ -241,6 +241,8 @@ def parity(metadata: Metadata) -> ParityOutput:
241
241
  >>> parity(metadata)
242
242
  ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
243
243
  """ # noqa: E501
244
+ if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
245
+ raise ValueError("No factors found in provided metadata.")
244
246
 
245
247
  chi_scores = np.zeros(metadata.discrete_data.shape[1])
246
248
  p_values = np.zeros_like(chi_scores)