dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/detectors/ood/metadata_ks_compare.py DELETED
@@ -1,129 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- import numbers
- import warnings
- from typing import Any, Mapping, NamedTuple
-
- import numpy as np
- from numpy.typing import NDArray
- from scipy.stats import iqr, ks_2samp
- from scipy.stats import wasserstein_distance as emd
-
- from dataeval.output import MappingOutput, set_metadata
-
-
- class MetadataKSResult(NamedTuple):
-     statistic: float
-     statistic_location: float
-     shift_magnitude: float
-     pvalue: float
-
-
- class KSOutput(MappingOutput[str, MetadataKSResult]):
-     """
-     Output dictionary class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
-
-     Attributes
-     ----------
-     key: str
-         Metadata feature names
-     value: NamedTuple[float, float, float, float]
-         Each value contains four floats, which are:
-         - statistic: the KS statistic itself
-         - statistic_location: its location within the range of the reference metadata
-         - shift_magnitude: the shift of new metadata relative to reference
-         - pvalue: the p-value from the KS two-sample test
-     """
-
-
- @set_metadata
- def meta_distribution_compare(
-     md0: Mapping[str, list[Any] | NDArray[Any]], md1: Mapping[str, list[Any] | NDArray[Any]]
- ) -> KSOutput:
-     """
-     Measures the featurewise distance between two metadata distributions, and computes a p-value to evaluate its
-     significance.
-
-     Uses the Earth Mover's Distance and the Kolmogorov-Smirnov two-sample test, featurewise.
-
-     Parameters
-     ----------
-     md0 : Mapping[str, list[Any] | NDArray[Any]]
-         A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
-     md1 : Mapping[str, list[Any] | NDArray[Any]]
-         Another set of arrays of values, indexed by metadata feature names, with one value per data example per
-         feature.
-
-     Returns
-     -------
-     dict[str, KstestResult]
-         A dictionary with keys corresponding to metadata feature names, and values that are KstestResult objects, as
-         defined by scipy.stats.ks_2samp. These values also have two additional attributes: shift_magnitude and
-         statistic_location. The first is the Earth Mover's Distance normalized by the interquartile range (IQR) of
-         the reference, while the second is the value at which the KS statistic has its maximum, measured in
-         IQR-normalized units relative to the median of the reference distribution.
-
-     Examples
-     --------
-     Imagine we have 3 data examples, and that the corresponding metadata contains 2 features called time and
-     altitude.
-
-     >>> md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
-     >>> md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}
-     >>> md_out = meta_distribution_compare(md0, md1)
-     >>> for k, v in md_out.items():
-     ...     print(f"{k}: { {kv: round(vv, 3) for kv, vv in v._asdict().items()} }")
-     time: {'statistic': 1.0, 'statistic_location': 0.444, 'shift_magnitude': 2.7, 'pvalue': 0.0}
-     altitude: {'statistic': 0.333, 'statistic_location': 0.478, 'shift_magnitude': 0.749, 'pvalue': 0.944}
-     """
-
-     if (metadata_keys := md0.keys()) != md1.keys():
-         raise ValueError(f"Both sets of metadata keys must be identical: {list(md0)}, {list(md1)}")
-
-     mdc = {}  # output dict
-     for k in metadata_keys:
-         mdc.update({k: {}})
-
-         x0, x1 = list(md0[k]), list(md1[k])
-
-         allx = x0 + x1  # "+" sign concatenates lists.
-
-         if not all(isinstance(allxi, numbers.Number) for allxi in allx):  # NB: np.nan *is* a number in this context.
-             continue  # non-numeric features will return an empty dict for feature k
-
-         # from Numerical Recipes in C, 3rd ed. p. 737. If too few points, warn and keep going.
-         if np.sqrt(((N := len(x0)) * (M := len(x1))) / (N + M)) < 4:
-             warnings.warn(
-                 f"Sample sizes of {N}, {M} for feature {k} will yield unreliable p-values from the KS test.",
-                 UserWarning,
-             )
-
-         xmin, xmax = min(allx), max(allx)
-         if xmin == xmax:  # only one value in this feature, so fill in the obvious results for feature k
-             mdc[k] = MetadataKSResult(
-                 **{"statistic": 0.0, "statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0}
-             )
-             continue
-
-         ks_result = ks_2samp(x0, x1, method="asymp")
-         dev = ks_result.statistic_location - xmin  # pyright: ignore (KSresult type)
-         loc = dev / (xmax - xmin) if xmax > xmin else dev
-
-         dX = iqr(x0)  # preferred value of dX, which is the scale of the the md0 values for feature k
-         dX = (max(x0) - min(x0)) / 2.0 if dX == 0 else dX  # reasonable alternative value of dX, when iqr is zero.
-         dX = 1.0 if dX == 0 else dX  # if dX is *still* zero, just avoid division by zero this way
-
-         drift = emd(x0, x1) / dX
-
-         mdc[k] = MetadataKSResult(
-             **{
-                 "statistic": ks_result.statistic,  # pyright: ignore
-                 "statistic_location": loc,
-                 "shift_magnitude": drift,
-                 "pvalue": ks_result.pvalue,  # pyright: ignore
-             }
-         )
-
-     return KSOutput(mdc)
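To make the shift_magnitude above concrete: it is the featurewise Earth Mover's Distance between the two metadata sets divided by the IQR of the reference feature. A minimal sketch re-running the docstring example and spelling out that arithmetic for the "time" feature (the module-path import reflects the 0.76.1 file location; the function was not exported via __all__):

    from dataeval.detectors.ood.metadata_ks_compare import meta_distribution_compare

    md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
    md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}

    out = meta_distribution_compare(md0, md1)
    for feature, ks in out.items():
        # For "time": EMD(md0, md1) = (6.6 + 5.7 + 5.52) / 3 = 5.94 and IQR(md0["time"]) = 2.2,
        # so shift_magnitude = 5.94 / 2.2 = 2.7, matching the docstring output above.
        print(feature, ks.shift_magnitude, ks.pvalue)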
dataeval/detectors/ood/metadata_least_likely.py DELETED
@@ -1,119 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- import numbers
- import warnings
- from typing import Any
-
- import numpy as np
- from numpy.typing import NDArray
-
-
- def get_least_likely_features(
-     metadata: dict[str, list[Any] | NDArray[Any]],
-     new_metadata: dict[str, list[Any] | NDArray[Any]],
-     is_ood: NDArray[np.bool_],
- ) -> list[tuple[str, float]]:
-     """Computes which metadata feature is most out-of-distribution (OOD) relative to a reference metadata set.
-
-     Given a reference metadata dictionary `metadata` (where each key maps to one scalar metadata feature), a second
-     metadata dictionary, and a corresponding boolean flag `is_ood` indicating whether each new example falls
-     out-of-distribution (OOD) relative to the reference, this function finds which metadata feature is the most OOD,
-     for each OOD example.
-
-     Parameters
-     ----------
-     metadata: dict[str, list[Any] | NDArray[Any]]
-         A reference set of arrays of values, indexed by metadata feature names, with one value per data example per
-         feature.
-     new_metadata: dict[str, list[Any] | NDArray[Any]]
-         A second metadata set, to be tested against the reference metadata. It is ok if the two meta data objects
-         hold different numbers of examples.
-     is_ood: NDArray[np.bool_]
-         A boolean array, with one value per new_metadata example, that indicates which examples are OOD.
-
-     Returns
-     -------
-     list[tuple[str, float]]
-         An array of names of the features of each OOD new_metadata example that were the most OOD.
-
-     Examples
-     --------
-     Imagine we have 3 data examples, and that the corresponding metadata contains 2 features called time and
-     altitude, as shown below.
-
-     >>> metadata = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
-     >>> new_metadata = {"time": [7.8, 11.12], "altitude": [532, -211101]}
-     >>> is_ood = np.array([True, True])
-     >>> get_least_likely_features(metadata, new_metadata, is_ood)
-     [('time', 2.0), ('altitude', 33.245346)]
-     """
-     # Raise errors for bad inputs...
-
-     if metadata.keys() != new_metadata.keys():
-         raise ValueError(f"Reference and test metadata keys must be identical: {list(metadata)}, {list(new_metadata)}")
-
-     md_lengths = {len(np.atleast_1d(v)) for v in metadata.values()}
-     new_md_lengths = {len(np.atleast_1d(v)) for v in new_metadata.values()}
-     if len(md_lengths) > 1 or len(new_md_lengths) > 1:
-         raise ValueError(f"All features must have same length, got lengths {md_lengths}, {new_md_lengths}")
-
-     n_reference, n_new = md_lengths.pop(), new_md_lengths.pop()  # possibly different numbers of metadata examples
-
-     if n_new != len(is_ood):
-         raise ValueError(f"is_ood flag must have same length as new metadata {n_new} but has length {len(is_ood)}.")
-
-     if n_reference < 3:  # too hard to define "in-distribution" with this few reference samples.
-         warnings.warn(
-             "We need at least 3 reference metadata examples to determine which "
-             f"features are least likely, but only got {n_reference}",
-             UserWarning,
-         )
-         return []
-
-     if not any(is_ood):
-         return []
-
-     # ...inputs are good, look for most deviant standardized features.
-
-     # largest standardized absolute deviation from the median observed so far for each example
-     deviation = np.zeros_like(is_ood, dtype=np.float32)
-
-     # name of feature that corresponds to `deviation` for each example
-     kmax = np.empty(len(is_ood), dtype=object)
-
-     for k, v in metadata.items():
-         # exclude cases where random happens to be out on tails, not interesting.
-         if k == "random":
-             continue
-
-         # Skip non-numerical features
-         if not all(isinstance(vi, numbers.Number) for vi in v):  # NB: np.nan *is* a number in this context.
-             continue
-
-         # Get standardization parameters from metadata
-         loc = np.median(v)  # ok, because we checked all were numeric
-         dev = np.asarray(v) - loc  # need to make array from v since it could be a list here.
-         posdev, negdev = dev[dev > 0], dev[dev < 0]
-         pos_scale = np.median(posdev) if posdev.any() else 1.0
-         neg_scale = np.abs(np.median(negdev)) if negdev.any() else 1.0
-
-         x, x0, dxp, dxn = np.atleast_1d(new_metadata[k]), loc, pos_scale, neg_scale  # just abbreviations
-         dxp = dxp if dxp > 0 else 1.0  # avoids dividing by zero below
-         dxn = dxn if dxn > 0 else 1.0
-
-         # xdev must be floating-point to avoid getting zero in an integer division.
-         xdev = (x - x0).astype(np.float64)
-         pos = xdev >= 0
-
-         X = np.zeros_like(xdev)
-         X[pos], X[~pos] = xdev[pos] / dxp, xdev[~pos] / dxn  # keeping track of possible asymmetry of x, but...
-         # ...below here, only need to think about absolute deviation.
-
-         abig = np.abs(X) > deviation
-         kmax[abig] = k
-         deviation[abig] = np.abs(X[abig])
-
-     unlikely_features = list(zip(kmax[is_ood], deviation[is_ood]))  # feature names, along with how far out they are.
-     return unlikely_features
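The function above scores each new value by its deviation from the reference median, scaled by the median positive or negative deviation of that feature, and reports the highest-scoring feature per OOD example. A sketch reproducing the docstring example, with the arithmetic behind the first reported score (module-path import assumed from the 0.76.1 layout):

    import numpy as np
    from dataeval.detectors.ood.metadata_least_likely import get_least_likely_features

    metadata = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
    new_metadata = {"time": [7.8, 11.12], "altitude": [532, -211101]}
    is_ood = np.array([True, True])

    # For "time": reference median = 3.4 and median positive deviation = 2.2, so the first
    # new example scores (7.8 - 3.4) / 2.2 = 2.0, which beats its "altitude" score and is
    # reported as that example's least likely feature.
    print(get_least_likely_features(metadata, new_metadata, is_ood))
    # [('time', 2.0), ('altitude', 33.245346)]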
dataeval/interop.py DELETED
@@ -1,69 +0,0 @@
- """Utility functions for interoperability with different array types."""
-
- from __future__ import annotations
-
- __all__ = []
-
- import logging
- from importlib import import_module
- from types import ModuleType
- from typing import Any, Iterable, Iterator
-
- import numpy as np
- from numpy.typing import ArrayLike, NDArray
-
- from dataeval.log import LogMessage
-
- _logger = logging.getLogger(__name__)
-
- _MODULE_CACHE = {}
-
-
- def _try_import(module_name) -> ModuleType | None:
-     if module_name in _MODULE_CACHE:
-         return _MODULE_CACHE[module_name]
-
-     try:
-         module = import_module(module_name)
-     except ImportError:  # pragma: no cover - covered by test_mindeps.py
-         _logger.log(logging.INFO, f"Unable to import {module_name}.")
-         module = None
-
-     _MODULE_CACHE[module_name] = module
-     return module
-
-
- def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
-     """Converts an ArrayLike to Numpy array without copying (if possible)"""
-     return to_numpy(array, copy=False)
-
-
- def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
-     """Converts an ArrayLike to new Numpy array"""
-     if array is None:
-         return np.ndarray([])
-
-     if isinstance(array, np.ndarray):
-         return array.copy() if copy else array
-
-     if array.__class__.__module__.startswith("tensorflow"):  # pragma: no cover - removed tf from deps
-         tf = _try_import("tensorflow")
-         if tf and tf.is_tensor(array):
-             _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
-             return array.numpy().copy() if copy else array.numpy()  # type: ignore
-
-     if array.__class__.__module__.startswith("torch"):
-         torch = _try_import("torch")
-         if torch and isinstance(array, torch.Tensor):
-             _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
-             numpy = array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
-             _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(numpy)}"))
-             return numpy
-
-     return np.array(array) if copy else np.asarray(array)
-
-
- def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
-     """Yields an iterator of numpy arrays from an ArrayLike"""
-     for array in iterable:
-         yield to_numpy(array)
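The deleted interop helpers convert framework tensors to NumPy arrays, lazily importing torch or tensorflow only when the input actually originates from one of those libraries. A brief usage sketch against the 0.76.1 module shown above (assumes torch is installed):

    import numpy as np
    import torch

    from dataeval.interop import as_numpy, to_numpy

    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)

    copied = to_numpy(t)   # detaches, moves to CPU, and copies into a new ndarray
    shared = as_numpy(t)   # same conversion, but without copying when possible
    assert isinstance(copied, np.ndarray) and copied.shape == (2, 3)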
dataeval/metrics/bias/coverage.py DELETED
@@ -1,194 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- import contextlib
- import math
- from dataclasses import dataclass
- from typing import Any, Literal
-
- import numpy as np
- from numpy.typing import ArrayLike, NDArray
- from scipy.spatial.distance import pdist, squareform
-
- from dataeval.interop import to_numpy
- from dataeval.output import Output, set_metadata
- from dataeval.utils.shared import flatten
-
- with contextlib.suppress(ImportError):
-     from matplotlib.figure import Figure
-
-
- def _plot(images: NDArray[Any], num_images: int) -> Figure:
-     """
-     Creates a single plot of all of the provided images
-
-     Parameters
-     ----------
-     images : NDArray
-         Array containing only the desired images to plot
-
-     Returns
-     -------
-     matplotlib.figure.Figure
-         Plot of all provided images
-     """
-     import matplotlib.pyplot as plt
-
-     num_images = min(num_images, len(images))
-
-     if images.ndim == 4:
-         images = np.moveaxis(images, 1, -1)
-     elif images.ndim == 3:
-         images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
-     else:
-         raise ValueError(
-             f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
-         )
-
-     rows = int(np.ceil(num_images / 3))
-     fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
-
-     if rows == 1:
-         for j in range(3):
-             if j >= len(images):
-                 continue
-             axs[j].imshow(images[j])
-             axs[j].axis("off")
-     else:
-         for i in range(rows):
-             for j in range(3):
-                 i_j = i * 3 + j
-                 if i_j >= len(images):
-                     continue
-                 axs[i, j].imshow(images[i_j])
-                 axs[i, j].axis("off")
-
-     fig.tight_layout()
-     return fig
-
-
- @dataclass(frozen=True)
- class CoverageOutput(Output):
-     """
-     Output class for :func:`coverage` :term:`bias<Bias>` metric.
-
-     Attributes
-     ----------
-     indices : NDArray[np.intp]
-         Array of uncovered indices
-     radii : NDArray[np.float64]
-         Array of critical value radii
-     critical_value : float
-         Radius for :term:`coverage<Coverage>`
-     """
-
-     indices: NDArray[np.intp]
-     radii: NDArray[np.float64]
-     critical_value: float
-
-     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
-         """
-         Plot the top k images together for visualization
-
-         Parameters
-         ----------
-         images : ArrayLike
-             Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
-         top_k : int, default 6
-             Number of images to plot (plotting assumes groups of 3)
-
-         Returns
-         -------
-         matplotlib.figure.Figure
-         """
-         # Determine which images to plot
-         highest_uncovered_indices = self.indices[:top_k]
-
-         # Grab the images
-         images = to_numpy(images)
-         selected_images = images[highest_uncovered_indices]
-
-         # Plot the images
-         fig = _plot(selected_images, top_k)
-
-         return fig
-
-
- @set_metadata
- def coverage(
-     embeddings: ArrayLike,
-     radius_type: Literal["adaptive", "naive"] = "adaptive",
-     k: int = 20,
-     percent: float = 0.01,
- ) -> CoverageOutput:
-     """
-     Class for evaluating :term:`coverage<Coverage>` and identifying images/samples that are in undercovered regions.
-
-     Parameters
-     ----------
-     embeddings : ArrayLike, shape - (N, P)
-         A dataset in an ArrayLike format.
-         Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
-     radius_type : {"adaptive", "naive"}, default "adaptive"
-         The function used to determine radius.
-     k : int, default 20
-         Number of observations required in order to be covered.
-         [1] suggests that a minimum of 20-50 samples is necessary.
-     percent : float, default 0.01
-         Percent of observations to be considered uncovered. Only applies to adaptive radius.
-
-     Returns
-     -------
-     CoverageOutput
-         Array of uncovered indices, critical value radii, and the radius for coverage
-
-     Raises
-     ------
-     ValueError
-         If length of :term:`embeddings<Embeddings>` is less than or equal to k
-     ValueError
-         If radius_type is unknown
-
-     Note
-     ----
-     Embeddings should be on the unit interval [0-1].
-
-     Example
-     -------
-     >>> results = coverage(embeddings)
-     >>> results.indices
-     array([447, 412, 8, 32, 63])
-     >>> results.critical_value
-     0.8459038956941765
-
-     Reference
-     ---------
-     This implementation is based on https://dl.acm.org/doi/abs/10.1145/3448016.3457315.
-
-     [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
-     """
-
-     # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
-     embeddings = to_numpy(embeddings)
-     n = len(embeddings)
-     if n <= k:
-         raise ValueError(
-             f"Number of observations n={n} is less than or equal to the specified number of neighbors k={k}."
-         )
-     mat = squareform(pdist(flatten(embeddings))).astype(np.float64)
-     sorted_dists = np.sort(mat, axis=1)
-     crit = sorted_dists[:, k + 1]
-
-     d = embeddings.shape[1]
-     if radius_type == "naive":
-         rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
-         pvals = np.where(crit > rho)[0]
-     elif radius_type == "adaptive":
-         # Use data adaptive cutoff as rho
-         selection = int(max(n * percent, 1))
-         pvals = np.argsort(crit)[::-1][:selection]
-         rho = float(np.mean(np.sort(crit)[::-1][selection - 1 : selection + 1]))
-     else:
-         raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
-     return CoverageOutput(pvals, crit, rho)
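The removed coverage metric flags the embeddings whose distance to their (k+1)th neighbor exceeds either a density-derived radius ("naive") or the cutoff set by the top percent of those distances ("adaptive"). A self-contained sketch against the 0.76.1 API; the import path, random embeddings, and printed results are illustrative assumptions only:

    import numpy as np

    from dataeval.metrics.bias import coverage  # assumed 0.76.1 public import path

    rng = np.random.default_rng(0)
    embeddings = rng.random((500, 16))  # N=500 observations in a P=16 dimensional unit cube

    result = coverage(embeddings, radius_type="adaptive", k=20, percent=0.01)
    print(result.indices)         # indices of the least-covered observations (top 1% here)
    print(result.critical_value)  # radius used as the coverage cutoff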