dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import math
6
+ from typing import Literal
7
+
8
+ import numpy as np
9
+ from scipy.spatial.distance import pdist, squareform
10
+
11
+ from dataeval.outputs import CoverageOutput
12
+ from dataeval.outputs._base import set_metadata
13
+ from dataeval.typing import ArrayLike
14
+ from dataeval.utils._array import ensure_embeddings, flatten
15
+
16
+
17
+ @set_metadata
18
+ def coverage(
19
+ embeddings: ArrayLike,
20
+ radius_type: Literal["adaptive", "naive"] = "adaptive",
21
+ num_observations: int = 20,
22
+ percent: float = 0.01,
23
+ ) -> CoverageOutput:
24
+ """
25
+ Class for evaluating :term:`coverage<Coverage>` and identifying images/samples that are in undercovered regions.
26
+
27
+ Parameters
28
+ ----------
29
+ embeddings : ArrayLike, shape - (N, P)
30
+ Dataset embeddings as unit interval [0, 1].
31
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
32
+ radius_type : {"adaptive", "naive"}, default "adaptive"
33
+ The function used to determine radius.
34
+ num_observations : int, default 20
35
+ Number of observations required in order to be covered.
36
+ [1] suggests that a minimum of 20-50 samples is necessary.
37
+ percent : float, default 0.01
38
+ Percent of observations to be considered uncovered. Only applies to adaptive radius.
39
+
40
+ Returns
41
+ -------
42
+ CoverageOutput
43
+ Array of uncovered indices, critical value radii, and the radius for coverage
44
+
45
+ Raises
46
+ ------
47
+ ValueError
48
+ If embeddings are not unit interval [0-1]
49
+ ValueError
50
+ If length of :term:`embeddings<Embeddings>` is less than or equal to num_observations
51
+ ValueError
52
+ If radius_type is unknown
53
+
54
+ Note
55
+ ----
56
+ Embeddings should be on the unit interval [0-1].
57
+
58
+ Example
59
+ -------
60
+ >>> results = coverage(embeddings)
61
+ >>> results.uncovered_indices
62
+ array([447, 412, 8, 32, 63])
63
+ >>> results.coverage_radius
64
+ 0.17592147193757596
65
+
66
+ Reference
67
+ ---------
68
+ This implementation is based on https://dl.acm.org/doi/abs/10.1145/3448016.3457315.
69
+
70
+ [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
71
+ """
72
+
73
+ # Calculate distance matrix, look at the (num_observations + 1)th farthest neighbor for each image.
74
+ embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=True)
75
+ len_embeddings = len(embeddings)
76
+ if len_embeddings <= num_observations:
77
+ raise ValueError(
78
+ f"Length of embeddings ({len_embeddings}) is less than or equal to the specified number of \
79
+ observations ({num_observations})."
80
+ )
81
+ embeddings_matrix = squareform(pdist(flatten(embeddings))).astype(np.float64)
82
+ sorted_dists = np.sort(embeddings_matrix, axis=1)
83
+ critical_value_radii = sorted_dists[:, num_observations + 1]
84
+
85
+ d = embeddings.shape[1]
86
+ if radius_type == "naive":
87
+ coverage_radius = (1 / math.sqrt(math.pi)) * (
88
+ (2 * num_observations * math.gamma(d / 2 + 1)) / (len_embeddings)
89
+ ) ** (1 / d)
90
+ uncovered_indices = np.where(critical_value_radii > coverage_radius)[0]
91
+ elif radius_type == "adaptive":
92
+ # Use data adaptive cutoff as coverage_radius
93
+ selection = int(max(len_embeddings * percent, 1))
94
+ uncovered_indices = np.argsort(critical_value_radii)[::-1][:selection]
95
+ coverage_radius = float(np.mean(np.sort(critical_value_radii)[::-1][selection - 1 : selection + 1]))
96
+ else:
97
+ raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
98
+ return CoverageOutput(uncovered_indices, critical_value_radii, coverage_radius)
@@ -2,113 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import contextlib
6
- from dataclasses import dataclass
7
- from typing import Any, Literal
5
+ from typing import Literal
8
6
 
9
7
  import numpy as np
10
8
  import scipy as sp
11
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
12
10
 
13
- from dataeval.output import Output, set_metadata
14
- from dataeval.utils.metadata import Metadata, get_counts
15
- from dataeval.utils.plot import heatmap
16
- from dataeval.utils.shared import get_method
17
-
18
- with contextlib.suppress(ImportError):
19
- from matplotlib.figure import Figure
20
-
21
-
22
- def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
23
- """
24
- Plots a formatted bar plot
25
-
26
- Parameters
27
- ----------
28
- labels : NDArray
29
- Array containing the labels for each bar
30
- bar_heights : NDArray
31
- Array containing the values for each bar
32
-
33
- Returns
34
- -------
35
- matplotlib.figure.Figure
36
- Bar plot figure
37
- """
38
- import matplotlib.pyplot as plt
39
-
40
- fig, ax = plt.subplots(figsize=(10, 10))
41
-
42
- ax.bar(labels, bar_heights)
43
- ax.set_xlabel("Factors")
44
-
45
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
46
-
47
- fig.tight_layout()
48
- return fig
49
-
50
-
51
- @dataclass(frozen=True)
52
- class DiversityOutput(Output):
53
- """
54
- Output class for :func:`diversity` :term:`bias<Bias>` metric.
55
-
56
- Attributes
57
- ----------
58
- diversity_index : NDArray[np.double]
59
- :term:`Diversity` index for classes and factors
60
- classwise : NDArray[np.double]
61
- Classwise diversity index [n_class x n_factor]
62
- factor_names : list[str]
63
- Names of each metadata factor
64
- class_list : NDArray[Any]
65
- Class labels for each value in the dataset
66
- """
67
-
68
- diversity_index: NDArray[np.double]
69
- classwise: NDArray[np.double]
70
- factor_names: list[str]
71
- class_list: NDArray[Any]
72
-
73
- def plot(
74
- self,
75
- row_labels: ArrayLike | None = None,
76
- col_labels: ArrayLike | None = None,
77
- plot_classwise: bool = False,
78
- ) -> Figure:
79
- """
80
- Plot a heatmap of diversity information
81
-
82
- Parameters
83
- ----------
84
- row_labels : ArrayLike or None, default None
85
- List/Array containing the labels for rows in the histogram
86
- col_labels : ArrayLike or None, default None
87
- List/Array containing the labels for columns in the histogram
88
- plot_classwise : bool, default False
89
- Whether to plot per-class balance instead of global balance
90
- """
91
- if plot_classwise:
92
- if row_labels is None:
93
- row_labels = self.class_list
94
- if col_labels is None:
95
- col_labels = self.factor_names
96
-
97
- fig = heatmap(
98
- self.classwise,
99
- row_labels,
100
- col_labels,
101
- xlabel="Factors",
102
- ylabel="Class",
103
- cbarlabel=f"Normalized {self.meta()['arguments']['method'].title()} Index",
104
- )
105
-
106
- else:
107
- # Creating label array for heat map axes
108
- heat_labels = np.concatenate((["class"], self.factor_names))
109
- fig = _plot(heat_labels, self.diversity_index)
110
-
111
- return fig
11
+ from dataeval.outputs import DiversityOutput
12
+ from dataeval.outputs._base import set_metadata
13
+ from dataeval.utils._bin import get_counts
14
+ from dataeval.utils._method import get_method
15
+ from dataeval.utils.data import Metadata
112
16
 
113
17
 
114
18
  def diversity_shannon(
@@ -191,6 +95,9 @@ def diversity_simpson(
191
95
  return ev_index
192
96
 
193
97
 
98
+ _DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
99
+
100
+
194
101
  @set_metadata
195
102
  def diversity(
196
103
  metadata: Metadata,
@@ -210,7 +117,7 @@ def diversity(
210
117
  Parameters
211
118
  ----------
212
119
  metadata : Metadata
213
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
120
+ Preprocessed metadata
214
121
  method : "simpson" or "shannon", default "simpson"
215
122
  The methodology used for defining diversity
216
123
 
@@ -231,27 +138,27 @@ def diversity(
231
138
 
232
139
  >>> div_simp = diversity(metadata, method="simpson")
233
140
  >>> div_simp.diversity_index
234
- array([0.6 , 0.80882353, 1. , 0.8 ])
141
+ array([0.6 , 0.809, 1. , 0.8 ])
235
142
 
236
143
  >>> div_simp.classwise
237
- array([[0.5 , 0.8 , 0.8 ],
238
- [0.63043478, 0.97560976, 0.52830189]])
144
+ array([[0.5 , 0.8 , 0.8 ],
145
+ [0.63 , 0.976, 0.528]])
239
146
 
240
147
  Compute Shannon diversity index of metadata and class labels
241
148
 
242
149
  >>> div_shan = diversity(metadata, method="shannon")
243
150
  >>> div_shan.diversity_index
244
- array([0.81127812, 0.9426312 , 1. , 0.91829583])
151
+ array([0.811, 0.943, 1. , 0.918])
245
152
 
246
153
  >>> div_shan.classwise
247
- array([[0.68260619, 0.91829583, 0.91829583],
248
- [0.81443569, 0.99107606, 0.76420451]])
154
+ array([[0.683, 0.918, 0.918],
155
+ [0.814, 0.991, 0.764]])
249
156
 
250
157
  See Also
251
158
  --------
252
159
  scipy.stats.entropy
253
160
  """
254
- diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
161
+ diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
255
162
  discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
256
163
  cnts = get_counts(discretized_data)
257
164
  num_bins = np.bincount(np.nonzero(cnts)[1])
@@ -3,39 +3,18 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import warnings
6
- from dataclasses import dataclass
7
- from typing import Any, Generic, TypeVar
6
+ from typing import Any
8
7
 
9
8
  import numpy as np
10
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
11
10
  from scipy.stats import chisquare
12
11
  from scipy.stats.contingency import chi2_contingency, crosstab
13
12
 
14
- from dataeval.interop import as_numpy, to_numpy
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.metadata import Metadata
17
-
18
- TData = TypeVar("TData", np.float64, NDArray[np.float64])
19
-
20
-
21
- @dataclass(frozen=True)
22
- class ParityOutput(Generic[TData], Output):
23
- """
24
- Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics.
25
-
26
- Attributes
27
- ----------
28
- score : np.float64 | NDArray[np.float64]
29
- chi-squared score(s) of the test
30
- p_value : np.float64 | NDArray[np.float64]
31
- p-value(s) of the test
32
- metadata_names : list[str] | None
33
- Names of each metadata factor
34
- """
35
-
36
- score: TData
37
- p_value: TData
38
- metadata_names: list[str] | None
13
+ from dataeval.outputs import LabelParityOutput, ParityOutput
14
+ from dataeval.outputs._base import set_metadata
15
+ from dataeval.typing import ArrayLike
16
+ from dataeval.utils._array import as_numpy
17
+ from dataeval.utils.data import Metadata
39
18
 
40
19
 
41
20
  def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
@@ -109,7 +88,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
109
88
  raise ValueError(f"No labels found in the {label_name} dataset")
110
89
  if np.any(label_dist < 5):
111
90
  warnings.warn(
112
- f"Labels {np.where(label_dist<5)[0]} in {label_name}"
91
+ f"Labels {np.where(label_dist < 5)[0]} in {label_name}"
113
92
  " dataset have frequencies less than 5. This may lead"
114
93
  " to invalid chi-squared evaluation.",
115
94
  UserWarning,
@@ -121,7 +100,7 @@ def label_parity(
121
100
  expected_labels: ArrayLike,
122
101
  observed_labels: ArrayLike,
123
102
  num_classes: int | None = None,
124
- ) -> ParityOutput[np.float64]:
103
+ ) -> LabelParityOutput:
125
104
  """
126
105
  Calculate the chi-square statistic to assess the :term:`parity<Parity>` \
127
106
  between expected and observed label distributions.
@@ -142,7 +121,7 @@ def label_parity(
142
121
 
143
122
  Returns
144
123
  -------
145
- ParityOutput[np.float64]
124
+ LabelParityOutput
146
125
  chi-squared score and :term`P-Value` of the test
147
126
 
148
127
  Raises
@@ -171,7 +150,7 @@ def label_parity(
171
150
  >>> expected_labels = rng.choice([0, 1, 2, 3, 4], (100))
172
151
  >>> observed_labels = rng.choice([2, 3, 0, 4, 1], (100))
173
152
  >>> label_parity(expected_labels, observed_labels)
174
- ParityOutput(score=14.007374204742625, p_value=0.0072715574616218, metadata_names=None)
153
+ LabelParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
175
154
  """
176
155
 
177
156
  # Calculate
@@ -179,8 +158,8 @@ def label_parity(
179
158
  num_classes = 0
180
159
 
181
160
  # Calculate the class frequencies associated with the datasets
182
- observed_dist = np.bincount(to_numpy(observed_labels), minlength=num_classes)
183
- expected_dist = np.bincount(to_numpy(expected_labels), minlength=num_classes)
161
+ observed_dist = np.bincount(as_numpy(observed_labels), minlength=num_classes)
162
+ expected_dist = np.bincount(as_numpy(expected_labels), minlength=num_classes)
184
163
 
185
164
  # Validate
186
165
  validate_dist(observed_dist, "observed")
@@ -202,11 +181,11 @@ def label_parity(
202
181
  )
203
182
 
204
183
  cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
205
- return ParityOutput(cs, p, None)
184
+ return LabelParityOutput(cs, p)
206
185
 
207
186
 
208
187
  @set_metadata
209
- def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
188
+ def parity(metadata: Metadata) -> ParityOutput:
210
189
  """
211
190
  Calculate chi-square statistics to assess the linear relationship \
212
191
  between multiple factors and class labels.
@@ -218,7 +197,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
218
197
  Parameters
219
198
  ----------
220
199
  metadata : Metadata
221
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
200
+ Preprocessed metadata
222
201
 
223
202
  Returns
224
203
  -------
@@ -250,22 +229,22 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
250
229
  --------
251
230
  Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
252
231
 
253
- >>> from dataeval.utils.metadata import preprocess
254
- >>> rng = np.random.default_rng(175)
255
- >>> labels = rng.choice([0, 1, 2], (100))
256
- >>> metadata_dict = {
257
- ... "age": list(rng.choice([25, 30, 35, 45], (100))),
258
- ... "income": list(rng.choice([50000, 65000, 80000], (100))),
259
- ... "gender": list(rng.choice(["M", "F"], (100))),
260
- ... }
261
- >>> continuous_factor_bincounts = {"age": 4, "income": 3}
262
- >>> metadata = preprocess(metadata_dict, labels, continuous_factor_bincounts)
232
+ >>> metadata = generate_random_metadata(
233
+ ... labels=["doctor", "artist", "teacher"],
234
+ ... factors={
235
+ ... "age": [25, 30, 35, 45],
236
+ ... "income": [50000, 65000, 80000],
237
+ ... "gender": ["M", "F"]},
238
+ ... length=100,
239
+ ... random_seed=175)
240
+ >>> metadata.continuous_factor_bins = {"age": 4, "income": 3}
263
241
  >>> parity(metadata)
264
- ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
242
+ ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
265
243
  """ # noqa: E501
244
+
266
245
  chi_scores = np.zeros(metadata.discrete_data.shape[1])
267
246
  p_values = np.zeros_like(chi_scores)
268
- not_enough_data = {}
247
+ insufficient_data = {}
269
248
  for i, col_data in enumerate(metadata.discrete_data.T):
270
249
  # Builds a contingency matrix where entry at index (r,c) represents
271
250
  # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -279,14 +258,14 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
279
258
  current_factor_name = metadata.discrete_factor_names[i]
280
259
  for int_factor, int_class in zip(counts[0], counts[1]):
281
260
  if contingency_matrix[int_factor, int_class] > 0:
282
- factor_category = unique_factor_values[int_factor]
283
- if current_factor_name not in not_enough_data:
284
- not_enough_data[current_factor_name] = {}
285
- if factor_category not in not_enough_data[current_factor_name]:
286
- not_enough_data[current_factor_name][factor_category] = []
287
- not_enough_data[current_factor_name][factor_category].append(
288
- (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
289
- )
261
+ factor_category = unique_factor_values[int_factor].item()
262
+ if current_factor_name not in insufficient_data:
263
+ insufficient_data[current_factor_name] = {}
264
+ if factor_category not in insufficient_data[current_factor_name]:
265
+ insufficient_data[current_factor_name][factor_category] = {}
266
+ class_name = metadata.class_names[int_class]
267
+ class_count = contingency_matrix[int_factor, int_class].item()
268
+ insufficient_data[current_factor_name][factor_category][class_name] = class_count
290
269
 
291
270
  # This deletes rows containing only zeros,
292
271
  # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
@@ -299,24 +278,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
299
278
  chi_scores[i] = chi2
300
279
  p_values[i] = p
301
280
 
302
- if not_enough_data:
303
- factor_msg = []
304
- for factor, fact_dict in not_enough_data.items():
305
- stacked_msg = []
306
- for key, value in fact_dict.items():
307
- msg = []
308
- for item in value:
309
- msg.append(f"label {item[0]}: {item[1]} occurrences")
310
- flat_msg = "\n\t\t".join(msg)
311
- stacked_msg.append(f"value {key} - {flat_msg}\n\t")
312
- factor_msg.append(factor + " - " + "".join(stacked_msg))
313
-
314
- message = "\n".join(factor_msg)
315
-
316
- warnings.warn(
317
- f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \n\
318
- Recommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",
319
- UserWarning,
320
- )
281
+ if insufficient_data:
282
+ warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
321
283
 
322
- return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
284
+ return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)
@@ -2,8 +2,19 @@
2
2
  Estimators calculate performance bounds and the statistical distance between datasets.
3
3
  """
4
4
 
5
- __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]
5
+ __all__ = [
6
+ "ber",
7
+ "clusterer",
8
+ "divergence",
9
+ "uap",
10
+ "BEROutput",
11
+ "ClustererOutput",
12
+ "DivergenceOutput",
13
+ "UAPOutput",
14
+ ]
6
15
 
7
- from dataeval.metrics.estimators.ber import BEROutput, ber
8
- from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
9
- from dataeval.metrics.estimators.uap import UAPOutput, uap
16
+ from dataeval.metrics.estimators._ber import ber
17
+ from dataeval.metrics.estimators._clusterer import clusterer
18
+ from dataeval.metrics.estimators._divergence import divergence
19
+ from dataeval.metrics.estimators._uap import uap
20
+ from dataeval.outputs._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
@@ -12,35 +12,19 @@ from __future__ import annotations
12
12
 
13
13
  __all__ = []
14
14
 
15
- from dataclasses import dataclass
16
15
  from typing import Literal
17
16
 
18
17
  import numpy as np
19
- from numpy.typing import ArrayLike, NDArray
18
+ from numpy.typing import NDArray
20
19
  from scipy.sparse import coo_matrix
21
20
  from scipy.stats import mode
22
21
 
23
- from dataeval.interop import as_numpy
24
- from dataeval.output import Output, set_metadata
25
- from dataeval.utils.shared import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
26
-
27
-
28
- @dataclass(frozen=True)
29
- class BEROutput(Output):
30
- """
31
- Output class for :func:`ber` estimator metric.
32
-
33
- Attributes
34
- ----------
35
- ber : float
36
- The upper bounds of the :term:`Bayes error rate<Bayes Error Rate (BER)>`
37
- ber_lower : float
38
- The lower bounds of the Bayes Error Rate
39
- """
40
-
41
- ber: float
42
-
43
- ber_lower: float
22
+ from dataeval.outputs import BEROutput
23
+ from dataeval.outputs._base import set_metadata
24
+ from dataeval.typing import ArrayLike
25
+ from dataeval.utils._array import as_numpy, ensure_embeddings
26
+ from dataeval.utils._method import get_method
27
+ from dataeval.utils._mst import compute_neighbors, minimum_spanning_tree
44
28
 
45
29
 
46
30
  def ber_mst(images: NDArray[np.float64], labels: NDArray[np.int_], k: int = 1) -> tuple[float, float]:
@@ -116,18 +100,21 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
116
100
  return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
117
101
 
118
102
 
103
+ _BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
104
+
105
+
119
106
  @set_metadata
120
- def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
107
+ def ber(embeddings: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
121
108
  """
122
109
  An estimator for Multi-class :term:`Bayes error rate<Bayes Error Rate (BER)>` \
123
110
  using FR or KNN test statistic basis.
124
111
 
125
112
  Parameters
126
113
  ----------
127
- images : ArrayLike (N, ... )
128
- Array of images or image :term:`embeddings<Embeddings>`
114
+ embeddings : ArrayLike (N, ... )
115
+ Array of image :term:`embeddings<Embeddings>`
129
116
  labels : ArrayLike (N, 1)
130
- Array of labels for each image or image embedding
117
+ Array of labels for each image
131
118
  k : int, default 1
132
119
  Number of nearest neighbors for KNN estimator -- ignored by MST estimator
133
120
  method : Literal["KNN", "MST"], default "KNN"
@@ -152,8 +139,34 @@ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN",
152
139
  >>> ber(images, labels)
153
140
  BEROutput(ber=0.04, ber_lower=0.020416847668728033)
154
141
  """
155
- ber_fn = get_method({"KNN": ber_knn, "MST": ber_mst}, method)
156
- X = as_numpy(images)
142
+ ber_fn = get_method(_BER_FN_MAP, method)
143
+ X = ensure_embeddings(embeddings, dtype=np.float64)
157
144
  y = as_numpy(labels)
158
145
  upper, lower = ber_fn(X, y, k)
159
146
  return BEROutput(upper, lower)
147
+
148
+
149
+ def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
150
+ """
151
+ Returns the classes and counts of from an array of labels
152
+
153
+ Parameters
154
+ ----------
155
+ label : NDArray
156
+ Numpy labels array
157
+
158
+ Returns
159
+ -------
160
+ Classes and counts
161
+
162
+ Raises
163
+ ------
164
+ ValueError
165
+ If the number of unique classes is less than 2
166
+ """
167
+ classes, counts = np.unique(labels, return_counts=True)
168
+ M = len(classes)
169
+ if M < 2:
170
+ raise ValueError("Label vector contains less than 2 classes!")
171
+ N = int(np.sum(counts))
172
+ return M, N
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+
6
+ from dataeval.outputs import ClustererOutput
7
+ from dataeval.typing import ArrayLike
8
+ from dataeval.utils._array import as_numpy
9
+
10
+
11
+ def clusterer(data: ArrayLike) -> ClustererOutput:
12
+ """
13
+ Uses hierarchical clustering on the flattened data and returns clustering
14
+ information.
15
+
16
+ Parameters
17
+ ----------
18
+ data : ArrayLike, shape - (N, ...)
19
+ A dataset in an ArrayLike format. Function expects the data to have 2
20
+ or more dimensions which will flatten to (N, P) where N number of
21
+ observations in a P-dimensional space.
22
+
23
+ Returns
24
+ -------
25
+ :class:`.ClustererOutput`
26
+
27
+ Note
28
+ ----
29
+ The clusterer works best when the length of the feature dimension, P, is
30
+ less than 500. If flattening a CxHxW image results in a dimension larger
31
+ than 500, then it is recommended to reduce the dimensions.
32
+
33
+ Example
34
+ -------
35
+ >>> clusterer(clusterer_images).clusters
36
+ array([ 2, 0, 0, 0, 0, 0, 4, 0, 3, 1, 1, 0, 2, 0, 0, 0, 0,
37
+ 4, 2, 0, 0, 1, 2, 0, 1, 3, 0, 3, 3, 4, 0, 0, 3, 0,
38
+ 3, -1, 0, 0, 2, 4, 3, 4, 0, 1, 0, -1, 3, 0, 0, 0])
39
+ """
40
+ # Delay load numba compiled functions
41
+ from dataeval.utils._clusterer import cluster
42
+
43
+ c = cluster(data)
44
+ return ClustererOutput(c.clusters, c.mst, c.linkage_tree, as_numpy(c.condensed_tree), c.membership_strengths)