dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -5,16 +5,17 @@ __all__ = []
5
5
  import contextlib
6
6
  import warnings
7
7
  from dataclasses import dataclass
8
- from typing import Any
8
+ from typing import Any, Literal, overload
9
9
 
10
10
  import numpy as np
11
11
  import scipy as sp
12
12
  from numpy.typing import NDArray
13
13
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
14
14
 
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.metadata import Metadata, get_counts
17
- from dataeval.utils.plot import heatmap
15
+ from dataeval._output import Output, set_metadata
16
+ from dataeval.utils._bin import get_counts
17
+ from dataeval.utils._plot import heatmap
18
+ from dataeval.utils.data import Metadata
18
19
 
19
20
  with contextlib.suppress(ImportError):
20
21
  from matplotlib.figure import Figure
@@ -23,8 +24,8 @@ with contextlib.suppress(ImportError):
23
24
  @dataclass(frozen=True)
24
25
  class BalanceOutput(Output):
25
26
  """
26
- Output class for :func:`balance` :term:`bias<Bias>` metric.
27
-
27
+ Output class for :func:`.balance` :term:`bias<Bias>` metric.
28
+
28
29
  Attributes
29
30
  ----------
30
31
  balance : NDArray[np.float64]
@@ -35,21 +36,62 @@ class BalanceOutput(Output):
35
36
  Estimate of mutual information between metadata factors and individual class labels
36
37
  factor_names : list[str]
37
38
  Names of each metadata factor
38
- class_list : NDArray
39
- Array of the class labels present in the dataset
39
+ class_names : list[str]
40
+ List of the class labels present in the dataset
40
41
  """
41
42
 
42
43
  balance: NDArray[np.float64]
43
44
  factors: NDArray[np.float64]
44
45
  classwise: NDArray[np.float64]
45
46
  factor_names: list[str]
46
- class_list: NDArray[Any]
47
+ class_names: list[str]
48
+
49
+ @overload
50
+ def _by_factor_type(
51
+ self,
52
+ attr: Literal["factor_names"],
53
+ factor_type: Literal["discrete", "continuous", "both"],
54
+ ) -> list[str]: ...
55
+
56
+ @overload
57
+ def _by_factor_type(
58
+ self,
59
+ attr: Literal["balance", "factors", "classwise"],
60
+ factor_type: Literal["discrete", "continuous", "both"],
61
+ ) -> NDArray[np.float64]: ...
62
+
63
+ def _by_factor_type(
64
+ self,
65
+ attr: Literal["balance", "factors", "classwise", "factor_names"],
66
+ factor_type: Literal["discrete", "continuous", "both"],
67
+ ) -> NDArray[np.float64] | list[str]:
68
+ # if not filtering by factor_type then just return the requested attribute without mask
69
+ if factor_type == "both":
70
+ return getattr(self, attr)
71
+
72
+ # create the mask for the selected factor_type
73
+ mask_lambda = (
74
+ (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
75
+ )
76
+
77
+ # return the masked attribute
78
+ if attr == "factor_names":
79
+ return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
80
+ else:
81
+ factor_type_mask = [mask_lambda(x) for x in self.factor_names]
82
+ if attr == "factors":
83
+ return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
84
+ elif attr == "balance":
85
+ return self.balance[factor_type_mask]
86
+ elif attr == "classwise":
87
+ return self.classwise[:, factor_type_mask]
47
88
 
48
89
  def plot(
49
90
  self,
50
91
  row_labels: list[Any] | NDArray[Any] | None = None,
51
92
  col_labels: list[Any] | NDArray[Any] | None = None,
52
93
  plot_classwise: bool = False,
94
+ factor_type: Literal["discrete", "continuous", "both"] = "discrete",
53
95
  ) -> Figure:
54
96
  """
55
97
  Plot a heatmap of balance information
@@ -62,15 +104,17 @@ class BalanceOutput(Output):
62
104
  List/Array containing the labels for columns in the histogram
63
105
  plot_classwise : bool, default False
64
106
  Whether to plot per-class balance instead of global balance
107
+ factor_type : "discrete", "continuous", or "both", default "discrete"
108
+ Whether to plot discretized values, continuous values, or to include both
65
109
  """
66
110
  if plot_classwise:
67
111
  if row_labels is None:
68
- row_labels = self.class_list
112
+ row_labels = self.class_names
69
113
  if col_labels is None:
70
- col_labels = self.factor_names
114
+ col_labels = self._by_factor_type("factor_names", factor_type)
71
115
 
72
116
  fig = heatmap(
73
- self.classwise,
117
+ self._by_factor_type("classwise", factor_type),
74
118
  row_labels,
75
119
  col_labels,
76
120
  xlabel="Factors",
@@ -79,13 +123,19 @@ class BalanceOutput(Output):
79
123
  )
80
124
  else:
81
125
  # Combine balance and factors results
82
- data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
126
+ data = np.concatenate(
127
+ [
128
+ self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
129
+ self._by_factor_type("factors", factor_type),
130
+ ],
131
+ axis=0,
132
+ )
83
133
  # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
84
134
  mask = np.triu(data + 1, k=0) < 1
85
135
  # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
86
136
  heat_data = np.where(mask, np.nan, data)[:-1]
87
137
  # Creating label array for heat map axes
88
- heat_labels = self.factor_names
138
+ heat_labels = self._by_factor_type("factor_names", factor_type)
89
139
 
90
140
  if row_labels is None:
91
141
  row_labels = heat_labels[:-1]
@@ -128,7 +178,7 @@ def balance(
128
178
  Parameters
129
179
  ----------
130
180
  metadata : Metadata
131
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
181
+ Preprocessed metadata
132
182
  num_neighbors : int, default 5
133
183
  Number of points to consider as neighbors
134
184
 
@@ -184,7 +234,7 @@ def balance(
184
234
  mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
185
235
  data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
186
236
  discretized_data = data
187
- if metadata.continuous_data is not None:
237
+ if len(metadata.continuous_data):
188
238
  data = np.hstack((data, metadata.continuous_data))
189
239
  discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
190
240
  discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
@@ -218,7 +268,7 @@ def balance(
218
268
  factors = nmi[1:, 1:]
219
269
 
220
270
  # assume class is a factor
221
- num_classes = metadata.class_names.size
271
+ num_classes = len(metadata.class_names)
222
272
  classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
223
273
 
224
274
  # classwise targets
@@ -8,12 +8,12 @@ from dataclasses import dataclass
8
8
  from typing import Any, Literal
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import ArrayLike, NDArray
11
+ from numpy.typing import NDArray
12
12
  from scipy.spatial.distance import pdist, squareform
13
13
 
14
- from dataeval.interop import to_numpy
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.shared import flatten
14
+ from dataeval._output import Output, set_metadata
15
+ from dataeval.typing import ArrayLike
16
+ from dataeval.utils._array import ensure_embeddings, flatten, to_numpy
17
17
 
18
18
  with contextlib.suppress(ImportError):
19
19
  from matplotlib.figure import Figure
@@ -71,21 +71,21 @@ def _plot(images: NDArray[Any], num_images: int) -> Figure:
71
71
  @dataclass(frozen=True)
72
72
  class CoverageOutput(Output):
73
73
  """
74
- Output class for :func:`coverage` :term:`bias<Bias>` metric.
74
+ Output class for :func:`.coverage` :term:`bias<Bias>` metric.
75
75
 
76
76
  Attributes
77
77
  ----------
78
- indices : NDArray[np.intp]
78
+ uncovered_indices : NDArray[np.intp]
79
79
  Array of uncovered indices
80
- radii : NDArray[np.float64]
80
+ critical_value_radii : NDArray[np.float64]
81
81
  Array of critical value radii
82
- critical_value : float
82
+ coverage_radius : float
83
83
  Radius for :term:`coverage<Coverage>`
84
84
  """
85
85
 
86
- indices: NDArray[np.intp]
87
- radii: NDArray[np.float64]
88
- critical_value: float
86
+ uncovered_indices: NDArray[np.intp]
87
+ critical_value_radii: NDArray[np.float64]
88
+ coverage_radius: float
89
89
 
90
90
  def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
91
91
  """
@@ -102,8 +102,9 @@ class CoverageOutput(Output):
102
102
  -------
103
103
  matplotlib.figure.Figure
104
104
  """
105
+
105
106
  # Determine which images to plot
106
- highest_uncovered_indices = self.indices[:top_k]
107
+ highest_uncovered_indices = self.uncovered_indices[:top_k]
107
108
 
108
109
  # Grab the images
109
110
  images = to_numpy(images)
@@ -119,7 +120,7 @@ class CoverageOutput(Output):
119
120
  def coverage(
120
121
  embeddings: ArrayLike,
121
122
  radius_type: Literal["adaptive", "naive"] = "adaptive",
122
- k: int = 20,
123
+ num_observations: int = 20,
123
124
  percent: float = 0.01,
124
125
  ) -> CoverageOutput:
125
126
  """
@@ -128,11 +129,11 @@ def coverage(
128
129
  Parameters
129
130
  ----------
130
131
  embeddings : ArrayLike, shape - (N, P)
131
- A dataset in an ArrayLike format.
132
- Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
132
+ Dataset embeddings as unit interval [0, 1].
133
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
133
134
  radius_type : {"adaptive", "naive"}, default "adaptive"
134
135
  The function used to determine radius.
135
- k : int, default 20
136
+ num_observations : int, default 20
136
137
  Number of observations required in order to be covered.
137
138
  [1] suggests that a minimum of 20-50 samples is necessary.
138
139
  percent : float, default 0.01
@@ -146,7 +147,9 @@ def coverage(
146
147
  Raises
147
148
  ------
148
149
  ValueError
149
- If length of :term:`embeddings<Embeddings>` is less than or equal to k
150
+ If embeddings are not unit interval [0-1]
151
+ ValueError
152
+ If length of :term:`embeddings<Embeddings>` is less than or equal to num_observations
150
153
  ValueError
151
154
  If radius_type is unknown
152
155
 
@@ -157,10 +160,10 @@ def coverage(
157
160
  Example
158
161
  -------
159
162
  >>> results = coverage(embeddings)
160
- >>> results.indices
163
+ >>> results.uncovered_indices
161
164
  array([447, 412, 8, 32, 63])
162
- >>> results.critical_value
163
- 0.8459038956941765
165
+ >>> results.coverage_radius
166
+ 0.17592147193757596
164
167
 
165
168
  Reference
166
169
  ---------
@@ -169,26 +172,29 @@ def coverage(
169
172
  [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
170
173
  """
171
174
 
172
- # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
173
- embeddings = to_numpy(embeddings)
174
- n = len(embeddings)
175
- if n <= k:
175
+ # Calculate distance matrix, look at the (num_observations + 1)th farthest neighbor for each image.
176
+ embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=True)
177
+ len_embeddings = len(embeddings)
178
+ if len_embeddings <= num_observations:
176
179
  raise ValueError(
177
- f"Number of observations n={n} is less than or equal to the specified number of neighbors k={k}."
180
+ f"Length of embeddings ({len_embeddings}) is less than or equal to the specified number of \
181
+ observations ({num_observations})."
178
182
  )
179
- mat = squareform(pdist(flatten(embeddings))).astype(np.float64)
180
- sorted_dists = np.sort(mat, axis=1)
181
- crit = sorted_dists[:, k + 1]
183
+ embeddings_matrix = squareform(pdist(flatten(embeddings))).astype(np.float64)
184
+ sorted_dists = np.sort(embeddings_matrix, axis=1)
185
+ critical_value_radii = sorted_dists[:, num_observations + 1]
182
186
 
183
187
  d = embeddings.shape[1]
184
188
  if radius_type == "naive":
185
- rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
186
- pvals = np.where(crit > rho)[0]
189
+ coverage_radius = (1 / math.sqrt(math.pi)) * (
190
+ (2 * num_observations * math.gamma(d / 2 + 1)) / (len_embeddings)
191
+ ) ** (1 / d)
192
+ uncovered_indices = np.where(critical_value_radii > coverage_radius)[0]
187
193
  elif radius_type == "adaptive":
188
- # Use data adaptive cutoff as rho
189
- selection = int(max(n * percent, 1))
190
- pvals = np.argsort(crit)[::-1][:selection]
191
- rho = float(np.mean(np.sort(crit)[::-1][selection - 1 : selection + 1]))
194
+ # Use data adaptive cutoff as coverage_radius
195
+ selection = int(max(len_embeddings * percent, 1))
196
+ uncovered_indices = np.argsort(critical_value_radii)[::-1][:selection]
197
+ coverage_radius = float(np.mean(np.sort(critical_value_radii)[::-1][selection - 1 : selection + 1]))
192
198
  else:
193
199
  raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
194
- return CoverageOutput(pvals, crit, rho)
200
+ return CoverageOutput(uncovered_indices, critical_value_radii, coverage_radius)
@@ -8,12 +8,14 @@ from typing import Any, Literal
8
8
 
9
9
  import numpy as np
10
10
  import scipy as sp
11
- from numpy.typing import ArrayLike, NDArray
11
+ from numpy.typing import NDArray
12
12
 
13
- from dataeval.output import Output, set_metadata
14
- from dataeval.utils.metadata import Metadata, get_counts
15
- from dataeval.utils.plot import heatmap
16
- from dataeval.utils.shared import get_method
13
+ from dataeval._output import Output, set_metadata
14
+ from dataeval.typing import ArrayLike
15
+ from dataeval.utils._bin import get_counts
16
+ from dataeval.utils._method import get_method
17
+ from dataeval.utils._plot import heatmap
18
+ from dataeval.utils.data import Metadata
17
19
 
18
20
  with contextlib.suppress(ImportError):
19
21
  from matplotlib.figure import Figure
@@ -37,7 +39,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
37
39
  """
38
40
  import matplotlib.pyplot as plt
39
41
 
40
- fig, ax = plt.subplots(figsize=(10, 10))
42
+ fig, ax = plt.subplots(figsize=(8, 8))
41
43
 
42
44
  ax.bar(labels, bar_heights)
43
45
  ax.set_xlabel("Factors")
@@ -51,7 +53,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
51
53
  @dataclass(frozen=True)
52
54
  class DiversityOutput(Output):
53
55
  """
54
- Output class for :func:`diversity` :term:`bias<Bias>` metric.
56
+ Output class for :func:`.diversity` :term:`bias<Bias>` metric.
55
57
 
56
58
  Attributes
57
59
  ----------
@@ -61,14 +63,14 @@ class DiversityOutput(Output):
61
63
  Classwise diversity index [n_class x n_factor]
62
64
  factor_names : list[str]
63
65
  Names of each metadata factor
64
- class_list : NDArray[Any]
66
+ class_names : list[str]
65
67
  Class labels for each value in the dataset
66
68
  """
67
69
 
68
70
  diversity_index: NDArray[np.double]
69
71
  classwise: NDArray[np.double]
70
72
  factor_names: list[str]
71
- class_list: NDArray[Any]
73
+ class_names: list[str]
72
74
 
73
75
  def plot(
74
76
  self,
@@ -90,7 +92,7 @@ class DiversityOutput(Output):
90
92
  """
91
93
  if plot_classwise:
92
94
  if row_labels is None:
93
- row_labels = self.class_list
95
+ row_labels = self.class_names
94
96
  if col_labels is None:
95
97
  col_labels = self.factor_names
96
98
 
@@ -191,6 +193,9 @@ def diversity_simpson(
191
193
  return ev_index
192
194
 
193
195
 
196
+ _DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
197
+
198
+
194
199
  @set_metadata
195
200
  def diversity(
196
201
  metadata: Metadata,
@@ -210,7 +215,7 @@ def diversity(
210
215
  Parameters
211
216
  ----------
212
217
  metadata : Metadata
213
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
218
+ Preprocessed metadata
214
219
  method : "simpson" or "shannon", default "simpson"
215
220
  The methodology used for defining diversity
216
221
 
@@ -251,7 +256,7 @@ def diversity(
251
256
  --------
252
257
  scipy.stats.entropy
253
258
  """
254
- diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
259
+ diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
255
260
  discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
256
261
  cnts = get_counts(discretized_data)
257
262
  num_bins = np.bincount(np.nonzero(cnts)[1])
@@ -2,40 +2,86 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ import contextlib
5
6
  import warnings
6
7
  from dataclasses import dataclass
7
8
  from typing import Any, Generic, TypeVar
8
9
 
9
10
  import numpy as np
10
- from numpy.typing import ArrayLike, NDArray
11
+ from numpy.typing import NDArray
11
12
  from scipy.stats import chisquare
12
13
  from scipy.stats.contingency import chi2_contingency, crosstab
13
14
 
14
- from dataeval.interop import as_numpy, to_numpy
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.metadata import Metadata
15
+ from dataeval._output import Output, set_metadata
16
+ from dataeval.typing import ArrayLike
17
+ from dataeval.utils._array import as_numpy
18
+ from dataeval.utils.data import Metadata
19
+
20
+ with contextlib.suppress(ImportError):
21
+ import pandas as pd
17
22
 
18
23
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
19
24
 
20
25
 
21
26
  @dataclass(frozen=True)
22
- class ParityOutput(Generic[TData], Output):
27
+ class BaseParityOutput(Generic[TData], Output):
28
+ score: TData
29
+ p_value: TData
30
+
31
+ def to_dataframe(self) -> pd.DataFrame:
32
+ """
33
+ Exports the parity output results to a pandas DataFrame.
34
+
35
+ Returns
36
+ -------
37
+ pd.DataFrame
38
+ """
39
+ import pandas as pd
40
+
41
+ return pd.DataFrame(
42
+ index=self.factor_names, # type: ignore - list[str] is documented as acceptable index type
43
+ data={
44
+ "score": self.score.round(2),
45
+ "p-value": self.p_value.round(2),
46
+ },
47
+ )
48
+
49
+
50
+ @dataclass(frozen=True)
51
+ class LabelParityOutput(BaseParityOutput[np.float64]):
52
+ """
53
+ Output class for :func:`.label_parity` :term:`bias<Bias>` metrics.
54
+
55
+ Attributes
56
+ ----------
57
+ score : np.float64
58
+ chi-squared score(s) of the test
59
+ p_value : np.float64
60
+ p-value(s) of the test
61
+ """
62
+
63
+
64
+ @dataclass(frozen=True)
65
+ class ParityOutput(BaseParityOutput[NDArray[np.float64]]):
23
66
  """
24
- Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics.
67
+ Output class for :func:`.parity` :term:`bias<Bias>` metrics.
25
68
 
26
69
  Attributes
27
70
  ----------
28
- score : np.float64 | NDArray[np.float64]
71
+ score : NDArray[np.float64]
29
72
  chi-squared score(s) of the test
30
- p_value : np.float64 | NDArray[np.float64]
73
+ p_value : NDArray[np.float64]
31
74
  p-value(s) of the test
32
- metadata_names : list[str] | None
75
+ factor_names : list[str]
33
76
  Names of each metadata factor
77
+ insufficient_data : dict
78
+ Dictionary of metadata factors with fewer than 5 class occurrences per value
34
79
  """
35
80
 
36
- score: TData
37
- p_value: TData
38
- metadata_names: list[str] | None
81
+ # score: NDArray[np.float64]
82
+ # p_value: NDArray[np.float64]
83
+ factor_names: list[str]
84
+ insufficient_data: dict[str, dict[int, dict[str, int]]]
39
85
 
40
86
 
41
87
  def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
@@ -109,7 +155,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
109
155
  raise ValueError(f"No labels found in the {label_name} dataset")
110
156
  if np.any(label_dist < 5):
111
157
  warnings.warn(
112
- f"Labels {np.where(label_dist<5)[0]} in {label_name}"
158
+ f"Labels {np.where(label_dist < 5)[0]} in {label_name}"
113
159
  " dataset have frequencies less than 5. This may lead"
114
160
  " to invalid chi-squared evaluation.",
115
161
  UserWarning,
@@ -121,7 +167,7 @@ def label_parity(
121
167
  expected_labels: ArrayLike,
122
168
  observed_labels: ArrayLike,
123
169
  num_classes: int | None = None,
124
- ) -> ParityOutput[np.float64]:
170
+ ) -> LabelParityOutput:
125
171
  """
126
172
  Calculate the chi-square statistic to assess the :term:`parity<Parity>` \
127
173
  between expected and observed label distributions.
@@ -142,7 +188,7 @@ def label_parity(
142
188
 
143
189
  Returns
144
190
  -------
145
- ParityOutput[np.float64]
191
+ LabelParityOutput
146
192
  chi-squared score and :term`P-Value` of the test
147
193
 
148
194
  Raises
@@ -171,7 +217,7 @@ def label_parity(
171
217
  >>> expected_labels = rng.choice([0, 1, 2, 3, 4], (100))
172
218
  >>> observed_labels = rng.choice([2, 3, 0, 4, 1], (100))
173
219
  >>> label_parity(expected_labels, observed_labels)
174
- ParityOutput(score=14.007374204742625, p_value=0.0072715574616218, metadata_names=None)
220
+ LabelParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
175
221
  """
176
222
 
177
223
  # Calculate
@@ -179,8 +225,8 @@ def label_parity(
179
225
  num_classes = 0
180
226
 
181
227
  # Calculate the class frequencies associated with the datasets
182
- observed_dist = np.bincount(to_numpy(observed_labels), minlength=num_classes)
183
- expected_dist = np.bincount(to_numpy(expected_labels), minlength=num_classes)
228
+ observed_dist = np.bincount(as_numpy(observed_labels), minlength=num_classes)
229
+ expected_dist = np.bincount(as_numpy(expected_labels), minlength=num_classes)
184
230
 
185
231
  # Validate
186
232
  validate_dist(observed_dist, "observed")
@@ -202,11 +248,11 @@ def label_parity(
202
248
  )
203
249
 
204
250
  cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
205
- return ParityOutput(cs, p, None)
251
+ return LabelParityOutput(cs, p)
206
252
 
207
253
 
208
254
  @set_metadata
209
- def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
255
+ def parity(metadata: Metadata) -> ParityOutput:
210
256
  """
211
257
  Calculate chi-square statistics to assess the linear relationship \
212
258
  between multiple factors and class labels.
@@ -218,7 +264,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
218
264
  Parameters
219
265
  ----------
220
266
  metadata : Metadata
221
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
267
+ Preprocessed metadata
222
268
 
223
269
  Returns
224
270
  -------
@@ -250,22 +296,21 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
250
296
  --------
251
297
  Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
252
298
 
253
- >>> from dataeval.utils.metadata import preprocess
254
- >>> rng = np.random.default_rng(175)
255
- >>> labels = rng.choice([0, 1, 2], (100))
256
- >>> metadata_dict = {
257
- ... "age": list(rng.choice([25, 30, 35, 45], (100))),
258
- ... "income": list(rng.choice([50000, 65000, 80000], (100))),
259
- ... "gender": list(rng.choice(["M", "F"], (100))),
260
- ... }
261
- >>> continuous_factor_bincounts = {"age": 4, "income": 3}
262
- >>> metadata = preprocess(metadata_dict, labels, continuous_factor_bincounts)
299
+ >>> metadata = generate_random_metadata(
300
+ ... labels=["doctor", "artist", "teacher"],
301
+ ... factors={
302
+ ... "age": [25, 30, 35, 45],
303
+ ... "income": [50000, 65000, 80000],
304
+ ... "gender": ["M", "F"]},
305
+ ... length=100,
306
+ ... random_seed=175)
307
+ >>> metadata.continuous_factor_bins = {"age": 4, "income": 3}
263
308
  >>> parity(metadata)
264
- ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
309
+ ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
265
310
  """ # noqa: E501
266
311
  chi_scores = np.zeros(metadata.discrete_data.shape[1])
267
312
  p_values = np.zeros_like(chi_scores)
268
- not_enough_data = {}
313
+ insufficient_data = {}
269
314
  for i, col_data in enumerate(metadata.discrete_data.T):
270
315
  # Builds a contingency matrix where entry at index (r,c) represents
271
316
  # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -279,14 +324,14 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
279
324
  current_factor_name = metadata.discrete_factor_names[i]
280
325
  for int_factor, int_class in zip(counts[0], counts[1]):
281
326
  if contingency_matrix[int_factor, int_class] > 0:
282
- factor_category = unique_factor_values[int_factor]
283
- if current_factor_name not in not_enough_data:
284
- not_enough_data[current_factor_name] = {}
285
- if factor_category not in not_enough_data[current_factor_name]:
286
- not_enough_data[current_factor_name][factor_category] = []
287
- not_enough_data[current_factor_name][factor_category].append(
288
- (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
289
- )
327
+ factor_category = unique_factor_values[int_factor].item()
328
+ if current_factor_name not in insufficient_data:
329
+ insufficient_data[current_factor_name] = {}
330
+ if factor_category not in insufficient_data[current_factor_name]:
331
+ insufficient_data[current_factor_name][factor_category] = {}
332
+ class_name = metadata.class_names[int_class]
333
+ class_count = contingency_matrix[int_factor, int_class].item()
334
+ insufficient_data[current_factor_name][factor_category][class_name] = class_count
290
335
 
291
336
  # This deletes rows containing only zeros,
292
337
  # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
@@ -299,24 +344,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
299
344
  chi_scores[i] = chi2
300
345
  p_values[i] = p
301
346
 
302
- if not_enough_data:
303
- factor_msg = []
304
- for factor, fact_dict in not_enough_data.items():
305
- stacked_msg = []
306
- for key, value in fact_dict.items():
307
- msg = []
308
- for item in value:
309
- msg.append(f"label {item[0]}: {item[1]} occurrences")
310
- flat_msg = "\n\t\t".join(msg)
311
- stacked_msg.append(f"value {key} - {flat_msg}\n\t")
312
- factor_msg.append(factor + " - " + "".join(stacked_msg))
313
-
314
- message = "\n".join(factor_msg)
315
-
316
- warnings.warn(
317
- f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \n\
318
- Recommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",
319
- UserWarning,
320
- )
347
+ if insufficient_data:
348
+ warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
321
349
 
322
- return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
350
+ return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)