dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +78 -11
  3. dataeval/detectors/drift/_mmd.py +9 -9
  4. dataeval/detectors/drift/_torch.py +7 -7
  5. dataeval/detectors/drift/_uncertainty.py +4 -4
  6. dataeval/detectors/linters/duplicates.py +3 -3
  7. dataeval/detectors/linters/outliers.py +3 -3
  8. dataeval/detectors/ood/ae.py +5 -4
  9. dataeval/detectors/ood/base.py +2 -2
  10. dataeval/detectors/ood/mixin.py +1 -1
  11. dataeval/detectors/ood/vae.py +2 -1
  12. dataeval/metadata/__init__.py +2 -2
  13. dataeval/metadata/_distance.py +11 -44
  14. dataeval/metadata/_ood.py +152 -33
  15. dataeval/metrics/bias/_balance.py +9 -5
  16. dataeval/metrics/bias/_diversity.py +3 -0
  17. dataeval/metrics/bias/_parity.py +2 -0
  18. dataeval/metrics/estimators/_ber.py +2 -1
  19. dataeval/metrics/stats/_base.py +20 -21
  20. dataeval/metrics/stats/_boxratiostats.py +1 -1
  21. dataeval/metrics/stats/_dimensionstats.py +2 -2
  22. dataeval/metrics/stats/_hashstats.py +2 -2
  23. dataeval/metrics/stats/_imagestats.py +8 -8
  24. dataeval/metrics/stats/_pixelstats.py +2 -2
  25. dataeval/metrics/stats/_visualstats.py +2 -2
  26. dataeval/outputs/__init__.py +5 -0
  27. dataeval/outputs/_base.py +50 -21
  28. dataeval/outputs/_bias.py +1 -1
  29. dataeval/outputs/_linters.py +4 -2
  30. dataeval/outputs/_metadata.py +61 -0
  31. dataeval/outputs/_stats.py +12 -6
  32. dataeval/typing.py +40 -9
  33. dataeval/utils/_mst.py +1 -2
  34. dataeval/utils/data/_embeddings.py +23 -19
  35. dataeval/utils/data/_metadata.py +16 -7
  36. dataeval/utils/data/_selection.py +22 -15
  37. dataeval/utils/data/_split.py +3 -2
  38. dataeval/utils/data/datasets/_base.py +4 -2
  39. dataeval/utils/data/datasets/_cifar10.py +17 -9
  40. dataeval/utils/data/datasets/_milco.py +18 -12
  41. dataeval/utils/data/datasets/_mnist.py +24 -8
  42. dataeval/utils/data/datasets/_ships.py +18 -8
  43. dataeval/utils/data/datasets/_types.py +1 -5
  44. dataeval/utils/data/datasets/_voc.py +47 -24
  45. dataeval/utils/data/selections/__init__.py +2 -0
  46. dataeval/utils/data/selections/_classfilter.py +5 -3
  47. dataeval/utils/data/selections/_prioritize.py +296 -0
  48. dataeval/utils/data/selections/_shuffle.py +13 -4
  49. dataeval/utils/torch/_gmm.py +3 -2
  50. dataeval/utils/torch/_internal.py +5 -5
  51. dataeval/utils/torch/trainer.py +8 -8
  52. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
  53. dataeval-0.83.0.dist-info/RECORD +105 -0
  54. dataeval/detectors/ood/metadata_ood_mi.py +0 -93
  55. dataeval-0.82.0.dist-info/RECORD +0 -104
  56. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
  57. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.82.0"
11
+ __version__ = "0.83.0"
12
12
 
13
13
  import logging
14
14
 
@@ -34,7 +34,12 @@ def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> N
34
34
  logger = logging.getLogger(__name__)
35
35
  if handler is None:
36
36
  handler = logging.StreamHandler() if handler is None else handler
37
- handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
37
+ handler.setFormatter(
38
+ logging.Formatter(
39
+ "%(asctime)s %(levelname)-8s %(name)s.%(filename)s:%(lineno)s - %(funcName)10s() | %(message)s"
40
+ )
41
+ )
38
42
  logger.addHandler(handler)
39
43
  logger.setLevel(level)
44
+ logging.DEBUG
40
45
  logger.debug(f"Added logging handler {handler} to logger: {__name__}")
dataeval/config.py CHANGED
@@ -4,36 +4,71 @@ Global configuration settings for DataEval.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
8
8
 
9
+ import sys
10
+ from typing import Union
11
+
12
+ if sys.version_info >= (3, 10):
13
+ from typing import TypeAlias
14
+ else:
15
+ from typing_extensions import TypeAlias
16
+
17
+ import numpy as np
9
18
  import torch
10
- from torch import device
11
19
 
12
- _device: device | None = None
20
+ ### GLOBALS ###
21
+
22
+ _device: torch.device | None = None
13
23
  _processes: int | None = None
24
+ _seed: int | None = None
25
+
26
+ ### CONSTS ###
14
27
 
28
+ EPSILON = 1e-10
15
29
 
16
- def set_device(device: str | device | int) -> None:
30
+ ### TYPES ###
31
+
32
+ DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
33
+ """
34
+ Type alias for types that are acceptable for specifying a torch.device.
35
+
36
+ See Also
37
+ --------
38
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
39
+ """
40
+
41
+ ### FUNCS ###
42
+
43
+
44
+ def _todevice(device: DeviceLike) -> torch.device:
45
+ return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
46
+
47
+
48
+ def set_device(device: DeviceLike) -> None:
17
49
  """
18
50
  Sets the default device to use when executing against a PyTorch backend.
19
51
 
20
52
  Parameters
21
53
  ----------
22
- device : str or int or `torch.device`
23
- The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
24
- documentation for more information.
54
+ device : DeviceLike
55
+ The default device to use. See documentation for more information.
56
+
57
+ See Also
58
+ --------
59
+ `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
25
60
  """
26
61
  global _device
27
- _device = torch.device(device)
62
+ _device = _todevice(device)
28
63
 
29
64
 
30
- def get_device(override: str | device | int | None = None) -> torch.device:
65
+ def get_device(override: DeviceLike | None = None) -> torch.device:
31
66
  """
32
67
  Returns the PyTorch device to use.
33
68
 
34
69
  Parameters
35
70
  ----------
36
- override : str or int or `torch.device` or None, default None
71
+ override : DeviceLike or None, default None
37
72
  The user specified override if provided, otherwise returns the default device.
38
73
 
39
74
  Returns
@@ -44,7 +79,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
44
79
  global _device
45
80
  return torch.get_default_device() if _device is None else _device
46
81
  else:
47
- return torch.device(override)
82
+ return _todevice(override)
48
83
 
49
84
 
50
85
  def set_max_processes(processes: int | None) -> None:
@@ -75,3 +110,35 @@ def get_max_processes() -> int | None:
75
110
  """
76
111
  global _processes
77
112
  return _processes
113
+
114
+
115
+ def set_seed(seed: int | None, all_generators: bool = False) -> None:
116
+ """
117
+ Sets the seed for use by classes that allow for a random state or seed.
118
+
119
+ Parameters
120
+ ----------
121
+ seed : int or None
122
+ The seed to use.
123
+ all_generators : bool, default False
124
+ Whether to set the seed for all generators, including NumPy and PyTorch.
125
+ """
126
+ global _seed
127
+ _seed = seed
128
+
129
+ if all_generators:
130
+ np.random.seed(seed)
131
+ torch.manual_seed(seed)
132
+
133
+
134
+ def get_seed() -> int | None:
135
+ """
136
+ Returns the seed for random state or seed.
137
+
138
+ Returns
139
+ -------
140
+ int or None
141
+ The seed to use.
142
+ """
143
+ global _seed
144
+ return _seed
@@ -14,7 +14,7 @@ from typing import Callable
14
14
 
15
15
  import torch
16
16
 
17
- from dataeval.config import get_device
17
+ from dataeval.config import DeviceLike, get_device
18
18
  from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
19
19
  from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
20
20
  from dataeval.outputs import DriftMMDOutput
@@ -31,7 +31,7 @@ class DriftMMD(BaseDrift):
31
31
  ----------
32
32
  x_ref : ArrayLike
33
33
  Data used as reference distribution.
34
- p_val : float | None, default 0.05
34
+ p_val : float or None, default 0.05
35
35
  :term:`P-value` used for significance of the statistical test for each feature.
36
36
  If the FDR correction method is used, this corresponds to the acceptable
37
37
  q-value.
@@ -39,14 +39,14 @@ class DriftMMD(BaseDrift):
39
39
  Whether the given reference data ``x_ref`` has been preprocessed yet.
40
40
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
41
41
  If ``False``, the reference data will also be preprocessed.
42
- update_x_ref : UpdateStrategy | None, default None
42
+ update_x_ref : UpdateStrategy or None, default None
43
43
  Reference data can optionally be updated using an UpdateStrategy class. Update
44
44
  using the last n instances seen by the detector with LastSeenUpdateStrategy
45
45
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
46
- preprocess_fn : Callable | None, default None
46
+ preprocess_fn : Callable or None, default None
47
47
  Function to preprocess the data before computing the data drift metrics.
48
48
  Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
49
- sigma : ArrayLike | None, default None
49
+ sigma : ArrayLike or None, default None
50
50
  Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
51
51
  bandwidth values as an array. The kernel evaluation is then averaged over
52
52
  those bandwidths.
@@ -54,9 +54,9 @@ class DriftMMD(BaseDrift):
54
54
  Whether to already configure the kernel bandwidth from the reference data.
55
55
  n_permutations : int, default 100
56
56
  Number of permutations used in the permutation test.
57
- device : str | None, default None
58
- Device type used. The default None uses the GPU and falls back on CPU.
59
- Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
57
+ device : DeviceLike or None, default None
58
+ The hardware device to use if specified, otherwise uses the DataEval
59
+ default or torch default.
60
60
 
61
61
  Example
62
62
  -------
@@ -84,7 +84,7 @@ class DriftMMD(BaseDrift):
84
84
  sigma: ArrayLike | None = None,
85
85
  configure_kernel_from_x_ref: bool = True,
86
86
  n_permutations: int = 100,
87
- device: str | torch.device | None = None,
87
+ device: DeviceLike | None = None,
88
88
  ) -> None:
89
89
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
90
90
 
@@ -17,7 +17,7 @@ import torch
17
17
  import torch.nn as nn
18
18
  from numpy.typing import NDArray
19
19
 
20
- from dataeval.config import get_device
20
+ from dataeval.config import DeviceLike, get_device
21
21
  from dataeval.utils.torch._internal import predict_batch
22
22
 
23
23
 
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
59
59
  def preprocess_drift(
60
60
  x: NDArray[Any],
61
61
  model: nn.Module,
62
- device: str | torch.device | None = None,
62
+ device: DeviceLike | None = None,
63
63
  preprocess_batch_fn: Callable | None = None,
64
64
  batch_size: int = int(1e10),
65
65
  dtype: type[np.generic] | torch.dtype = np.float32,
@@ -73,15 +73,15 @@ def preprocess_drift(
73
73
  Batch of instances.
74
74
  model : nn.Module
75
75
  Model used for preprocessing.
76
- device : torch.device | None, default None
77
- Device type used. The default None tries to use the GPU and falls back on CPU.
78
- Can be specified by passing either torch.device('cuda') or torch.device('cpu').
79
- preprocess_batch_fn : Callable | None, default None
76
+ device : DeviceLike or None, default None
77
+ The hardware device to use if specified, otherwise uses the DataEval
78
+ default or torch default.
79
+ preprocess_batch_fn : Callable or None, default None
80
80
  Optional batch preprocessing function. For example to convert a list of objects
81
81
  to a batch which can be processed by the PyTorch model.
82
82
  batch_size : int, default 1e10
83
83
  Batch size used during prediction.
84
- dtype : np.dtype | torch.dtype, default np.float32
84
+ dtype : np.dtype or torch.dtype, default np.float32
85
85
  Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
86
86
 
87
87
  Returns
@@ -85,20 +85,20 @@ class DriftUncertainty:
85
85
  Whether the given reference data ``x_ref`` has been preprocessed yet.
86
86
  If ``True``, only the test data ``x`` will be preprocessed at prediction time.
87
87
  If ``False``, the reference data will also be preprocessed.
88
- update_x_ref : UpdateStrategy | None, default None
88
+ update_x_ref : UpdateStrategy or None, default None
89
89
  Reference data can optionally be updated using an UpdateStrategy class. Update
90
90
  using the last n instances seen by the detector with LastSeenUpdateStrategy
91
91
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
92
- preds_type : "probs" | "logits", default "probs"
92
+ preds_type : "probs" or "logits", default "probs"
93
93
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
94
94
  'logits' (in [-inf,inf]).
95
95
  batch_size : int, default 32
96
96
  Batch size used to evaluate model. Only relevant when backend has been
97
97
  specified for batch prediction.
98
- preprocess_batch_fn : Callable | None, default None
98
+ preprocess_batch_fn : Callable or None, default None
99
99
  Optional batch preprocessing function. For example to convert a list of
100
100
  objects to a batch which can be processed by the model.
101
- device : str | None, default None
101
+ device : DeviceLike or None, default None
102
102
  Device type used. The default None tries to use the GPU and falls back on
103
103
  CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
104
104
 
@@ -88,13 +88,13 @@ class Duplicates:
88
88
  """
89
89
 
90
90
  if isinstance(hashes, HashStatsOutput):
91
- return DuplicatesOutput(**self._get_duplicates(hashes.dict()))
91
+ return DuplicatesOutput(**self._get_duplicates(hashes.data()))
92
92
 
93
93
  if not isinstance(hashes, Sequence):
94
94
  raise TypeError("Invalid stats output type; only use output from hashstats.")
95
95
 
96
96
  combined, dataset_steps = combine_stats(hashes)
97
- duplicates = self._get_duplicates(combined.dict())
97
+ duplicates = self._get_duplicates(combined.data())
98
98
 
99
99
  # split up results from combined dataset into individual dataset buckets
100
100
  for dup_type, dup_list in duplicates.items():
@@ -136,5 +136,5 @@ class Duplicates:
136
136
  """ # noqa: E501
137
137
  images = Images(data) if isinstance(data, Dataset) else data
138
138
  self.stats = hashstats(images)
139
- duplicates = self._get_duplicates(self.stats.dict())
139
+ duplicates = self._get_duplicates(self.stats.data())
140
140
  return DuplicatesOutput(**duplicates)
@@ -169,7 +169,7 @@ class Outliers:
169
169
  {}
170
170
  """ # noqa: E501
171
171
  if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
172
- return OutliersOutput(self._get_outliers(stats.dict()))
172
+ return OutliersOutput(self._get_outliers(stats.data()))
173
173
 
174
174
  if not isinstance(stats, Sequence):
175
175
  raise TypeError(
@@ -189,7 +189,7 @@ class Outliers:
189
189
  output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
190
190
  for _, indices in stats_map.items():
191
191
  substats, dataset_steps = combine_stats([stats[i] for i in indices])
192
- outliers = self._get_outliers(substats.dict())
192
+ outliers = self._get_outliers(substats.data())
193
193
  for idx, issue in outliers.items():
194
194
  k, v = get_dataset_step_from_idx(idx, dataset_steps)
195
195
  output_list[indices[k]][v] = issue
@@ -225,5 +225,5 @@ class Outliers:
225
225
  """
226
226
  images = Images(data) if isinstance(data, Dataset) else data
227
227
  self.stats = imagestats(images)
228
- outliers = self._get_outliers(self.stats.dict())
228
+ outliers = self._get_outliers(self.stats.data())
229
229
  return OutliersOutput(outliers)
@@ -18,6 +18,7 @@ import numpy as np
18
18
  import torch
19
19
  from numpy.typing import NDArray
20
20
 
21
+ from dataeval.config import DeviceLike
21
22
  from dataeval.detectors.ood.base import OODBase
22
23
  from dataeval.outputs import OODScoreOutput
23
24
  from dataeval.typing import ArrayLike
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
33
34
  model : torch.nn.Module
34
35
  An autoencoder model to use for encoding and reconstruction of images
35
36
  for detection of out-of-distribution samples.
36
- device : str or torch.Device or None, default None
37
- The device to use for the detector. None will default to the global
38
- configuration selection if set, otherwise "cuda" then "cpu" by availability.
37
+ device : DeviceLike or None, default None
38
+ The hardware device to use if specified, otherwise uses the DataEval
39
+ default or torch default.
39
40
 
40
41
  Example
41
42
  -------
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
57
58
  array([ True, True, False, True, True, True, True, True])
58
59
  """
59
60
 
60
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
61
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
61
62
  super().__init__(model, device)
62
63
 
63
64
  def fit(
@@ -14,7 +14,7 @@ from typing import Callable, cast
14
14
 
15
15
  import torch
16
16
 
17
- from dataeval.config import get_device
17
+ from dataeval.config import DeviceLike, get_device
18
18
  from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
19
19
  from dataeval.typing import ArrayLike
20
20
  from dataeval.utils._array import to_numpy
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
23
23
 
24
24
 
25
25
  class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
26
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
26
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
27
27
  self.device: torch.device = get_device(device)
28
28
  super().__init__(model)
29
29
 
@@ -157,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
157
157
  # compute outlier scores
158
158
  score = self.score(X, batch_size=batch_size)
159
159
  ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
160
- return OODOutput(is_ood=ood_pred, **score.dict())
160
+ return OODOutput(is_ood=ood_pred, **score.data())
@@ -17,6 +17,7 @@ from typing import Callable
17
17
  import numpy as np
18
18
  import torch
19
19
 
20
+ from dataeval.config import DeviceLike
20
21
  from dataeval.detectors.ood.base import OODBase
21
22
  from dataeval.outputs import OODScoreOutput
22
23
  from dataeval.typing import ArrayLike
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
34
35
  An Autoencoder model.
35
36
  """
36
37
 
37
- def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
38
+ def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
38
39
  super().__init__(model, device)
39
40
 
40
41
  def fit(
@@ -1,6 +1,6 @@
1
1
  """Explanatory functions using metadata and additional features such as ood or drift"""
2
2
 
3
- __all__ = ["most_deviated_factors", "metadata_distance"]
3
+ __all__ = ["find_ood_predictors", "metadata_distance", "find_most_deviated_factors"]
4
4
 
5
5
  from dataeval.metadata._distance import metadata_distance
6
- from dataeval.metadata._ood import most_deviated_factors
6
+ from dataeval.metadata._ood import find_most_deviated_factors, find_ood_predictors
@@ -10,7 +10,8 @@ from scipy.stats import iqr, ks_2samp
10
10
  from scipy.stats import wasserstein_distance as emd
11
11
 
12
12
  from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
13
- from dataeval.outputs._base import MappingOutput
13
+ from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
14
+ from dataeval.outputs._base import set_metadata
14
15
  from dataeval.typing import ArrayLike
15
16
  from dataeval.utils.data import Metadata
16
17
 
@@ -23,41 +24,6 @@ class KSType(NamedTuple):
23
24
  pvalue: float
24
25
 
25
26
 
26
- class MetadataKSResult(NamedTuple):
27
- """
28
- Attributes
29
- ----------
30
- statistic : float
31
- the KS statistic
32
- location : float
33
- The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
34
- to the median of the reference distribution.
35
- dist : float
36
- The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
37
- pvalue : float
38
- The p-value from the KS two-sample test
39
- """
40
-
41
- statistic: float
42
- location: float
43
- dist: float
44
- pvalue: float
45
-
46
-
47
- class KSOutput(MappingOutput[str, MetadataKSResult]):
48
- """
49
- Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
50
-
51
-
52
- Attributes
53
- ----------
54
- key: str
55
- Metadata feature names
56
- value: :class:`MetadataKSResult`
57
- Output per feature name containing the statistic, statistic location, distance, and pvalue.
58
- """
59
-
60
-
61
27
  def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
62
28
  """Calculates the shift magnitude between x1 and x2 scaled by x1"""
63
29
 
@@ -74,7 +40,8 @@ def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
74
40
  return distance if xmin == xmax else distance / (xmax - xmin)
75
41
 
76
42
 
77
- def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
43
+ @set_metadata
44
+ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
78
45
  """
79
46
  Measures the feature-wise distance between two continuous metadata distributions and
80
47
  computes a p-value to evaluate its significance.
@@ -90,8 +57,8 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
90
57
 
91
58
  Returns
92
59
  -------
93
- dict[str, KstestResult]
94
- A dictionary with keys corresponding to metadata feature names, and values that are KstestResult objects, as
60
+ MetadataDistanceOutput
61
+ A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
95
62
  defined by scipy.stats.ks_2samp.
96
63
 
97
64
  See Also
@@ -110,7 +77,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
110
77
  >>> list(output)
111
78
  ['time', 'altitude']
112
79
  >>> output["time"]
113
- MetadataKSResult(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
80
+ MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
114
81
  """
115
82
 
116
83
  _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
@@ -134,7 +101,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
134
101
  )
135
102
 
136
103
  # Set default for statistic, location, and magnitude to zero and pvalue to one
137
- results: dict[str, MetadataKSResult] = {}
104
+ results: dict[str, MetadataDistanceValues] = {}
138
105
 
139
106
  # Per factor
140
107
  for i, fname in enumerate(fnames):
@@ -147,7 +114,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
147
114
 
148
115
  # Default case
149
116
  if xmin == xmax:
150
- results[fname] = MetadataKSResult(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
117
+ results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
151
118
  continue
152
119
 
153
120
  ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
@@ -157,11 +124,11 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
157
124
 
158
125
  drift = _calculate_drift(fdata1, fdata2)
159
126
 
160
- results[fname] = MetadataKSResult(
127
+ results[fname] = MetadataDistanceValues(
161
128
  statistic=ks_result.statistic,
162
129
  location=loc,
163
130
  dist=drift,
164
131
  pvalue=ks_result.pvalue,
165
132
  )
166
133
 
167
- return KSOutput(results)
134
+ return MetadataDistanceOutput(results)