dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +10 -11
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
  16. dataeval/detectors/ood/__init__.py +8 -16
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -4
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/utils/split_dataset.py +486 -0
  52. dataeval/utils/tensorflow/__init__.py +9 -7
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +49 -43
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
  67. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -7
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.0.dist-info/RECORD +0 -80
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
@@ -1,40 +1,77 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["DiversityOutput", "diversity"]
4
+
3
5
  from dataclasses import dataclass
4
- from typing import Literal, Mapping
6
+ from typing import Any, Literal, Mapping
5
7
 
6
8
  import numpy as np
7
9
  from numpy.typing import ArrayLike, NDArray
8
10
 
9
- from dataeval._internal.metrics.utils import entropy, get_counts, get_method, get_num_bins, preprocess_metadata
10
- from dataeval._internal.output import OutputMetadata, set_metadata
11
+ from dataeval.metrics.bias.metadata import entropy, get_counts, get_num_bins, heatmap, preprocess_metadata
12
+ from dataeval.output import OutputMetadata, set_metadata
13
+ from dataeval.utils.shared import get_method
11
14
 
12
15
 
13
16
  @dataclass(frozen=True)
14
17
  class DiversityOutput(OutputMetadata):
15
18
  """
16
- Output class for :func:`diversity` bias metric
19
+ Output class for :func:`diversity` :term:`bias<Bias>` metric
17
20
 
18
21
  Attributes
19
22
  ----------
20
23
  diversity_index : NDArray[np.float64]
21
- Diversity index for classes and factors
24
+ :term:`Diversity` index for classes and factors
22
25
  classwise : NDArray[np.float64]
23
26
  Classwise diversity index [n_class x n_factor]
27
+ class_list: NDArray[np.int64]
28
+ Class labels for each value in the dataset
29
+ metadata_names: list[str]
30
+ Names of each metadata factor
24
31
  """
25
32
 
26
33
  diversity_index: NDArray[np.float64]
27
34
  classwise: NDArray[np.float64]
28
35
 
36
+ class_list: NDArray[np.int64]
37
+ metadata_names: list[str]
38
+
39
+ method: Literal["shannon", "simpson"]
40
+
41
+ def plot(self, row_labels: NDArray[Any] | None = None, col_labels: NDArray[Any] | None = None) -> None:
42
+ """
43
+ Plot a heatmap of diversity information
44
+
45
+ Parameters
46
+ ----------
47
+ row_labels: NDArray | None, default None
48
+ Array containing the labels for rows in the histogram
49
+ col_labels: NDArray | None, default None
50
+ Array containing the labels for columns in the histogram
51
+ """
52
+ if row_labels is None:
53
+ row_labels = np.unique(self.class_list)
54
+ if col_labels is None:
55
+ col_labels = np.array(self.metadata_names)
56
+
57
+ heatmap(
58
+ self.classwise,
59
+ row_labels,
60
+ col_labels,
61
+ xlabel="Factors",
62
+ ylabel="Class",
63
+ cbarlabel=f"Normalized {self.method.title()} Index",
64
+ )
65
+
29
66
 
30
67
  def diversity_shannon(
31
- data: NDArray,
68
+ data: NDArray[Any],
32
69
  names: list[str],
33
70
  is_categorical: list[bool],
34
71
  subset_mask: NDArray[np.bool_] | None = None,
35
- ) -> NDArray:
72
+ ) -> NDArray[np.float64]:
36
73
  """
37
- Compute diversity for discrete/categorical variables and, through standard
74
+ Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
38
75
  histogram binning, for continuous variables.
39
76
 
40
77
  We define diversity as a normalized form of the Shannon entropy.
@@ -79,13 +116,13 @@ def diversity_shannon(
79
116
 
80
117
 
81
118
  def diversity_simpson(
82
- data: NDArray,
119
+ data: NDArray[Any],
83
120
  names: list[str],
84
121
  is_categorical: list[bool],
85
122
  subset_mask: NDArray[np.bool_] | None = None,
86
- ) -> NDArray:
123
+ ) -> NDArray[np.float64]:
87
124
  """
88
- Compute diversity for discrete/categorical variables and, through standard
125
+ Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
89
126
  histogram binning, for continuous variables.
90
127
 
91
128
  We define diversity as the inverse Simpson diversity index linearly rescaled to the unit interval.
@@ -139,16 +176,13 @@ def diversity_simpson(
139
176
  return ev_index
140
177
 
141
178
 
142
- DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
143
-
144
-
145
- @set_metadata("dataeval.metrics")
179
+ @set_metadata()
146
180
  def diversity(
147
181
  class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
148
182
  ) -> DiversityOutput:
149
183
  """
150
- Compute diversity and classwise diversity for discrete/categorical variables and, through standard
151
- histogram binning, for continuous variables.
184
+ Compute :term:`diversity<Diversity>` and classwise diversity for discrete/categorical variables and,
185
+ through standard histogram binning, for continuous variables.
152
186
 
153
187
  We define diversity as a normalized form of the inverse Simpson diversity index.
154
188
 
@@ -202,12 +236,12 @@ def diversity(
202
236
  --------
203
237
  numpy.histogram
204
238
  """
205
- diversity_fn = get_method(DIVERSITY_FN_MAP, method)
239
+ diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
206
240
  data, names, is_categorical = preprocess_metadata(class_labels, metadata)
207
241
  diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
208
242
 
209
243
  class_idx = names.index("class_label")
210
- class_lbl = data[:, class_idx]
244
+ class_lbl = np.array(data[:, class_idx], dtype=int)
211
245
 
212
246
  u_classes = np.unique(class_lbl)
213
247
  num_factors = len(names)
@@ -218,4 +252,4 @@ def diversity(
218
252
  diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
219
253
  div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
220
254
 
221
- return DiversityOutput(diversity_index, div_no_class)
255
+ return DiversityOutput(diversity_index, div_no_class, class_lbl, list(metadata.keys()), method)
@@ -0,0 +1,275 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Mapping
6
+
7
+ import numpy as np
8
+ from numpy.typing import ArrayLike, NDArray
9
+ from scipy.stats import entropy as sp_entropy
10
+
11
+ from dataeval.interop import to_numpy
12
+
13
+
14
+ def get_counts(
15
+ data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
16
+ ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
17
+ """
18
+ Initialize dictionary of histogram counts --- treat categorical values
19
+ as histogram bins.
20
+
21
+ Parameters
22
+ ----------
23
+ subset_mask: NDArray[np.bool_] | None
24
+ Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
25
+
26
+ Returns
27
+ -------
28
+ counts: Dict
29
+ histogram counts per metadata factor in `factors`. Each
30
+ factor will have a different number of bins. Counts get reused
31
+ across metrics, so hist_counts are cached but only if computed
32
+ globally, i.e. without masked samples.
33
+ """
34
+
35
+ hist_counts, hist_bins = {}, {}
36
+ # np.where needed to satisfy linter
37
+ mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
38
+
39
+ for cdx, fn in enumerate(names):
40
+ # linter doesn't like double indexing
41
+ col_data = data[mask, cdx].squeeze()
42
+ if is_categorical[cdx]:
43
+ # if discrete, use unique values as bins
44
+ bins, cnts = np.unique(col_data, return_counts=True)
45
+ else:
46
+ bins = hist_bins.get(fn, "auto")
47
+ cnts, bins = np.histogram(col_data, bins=bins, density=True)
48
+
49
+ hist_counts[fn] = cnts
50
+ hist_bins[fn] = bins
51
+
52
+ return hist_counts, hist_bins
53
+
54
+
55
+ def entropy(
56
+ data: NDArray[Any],
57
+ names: list[str],
58
+ is_categorical: list[bool],
59
+ normalized: bool = False,
60
+ subset_mask: NDArray[np.bool_] | None = None,
61
+ ) -> NDArray[np.float64]:
62
+ """
63
+ Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
64
+ ClasswiseBalance, and Classwise Diversity.
65
+
66
+ Compute entropy for discrete/categorical variables and for continuous variables through standard
67
+ histogram binning.
68
+
69
+ Parameters
70
+ ----------
71
+ normalized: bool
72
+ Flag that determines whether or not to normalize entropy by log(num_bins)
73
+ subset_mask: NDArray[np.bool_] | None
74
+ Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
75
+
76
+ Note
77
+ ----
78
+ For continuous variables, histogram bins are chosen automatically. See
79
+ numpy.histogram for details.
80
+
81
+ Returns
82
+ -------
83
+ ent: NDArray[np.float64]
84
+ Entropy estimate per column of X
85
+
86
+ See Also
87
+ --------
88
+ numpy.histogram
89
+ scipy.stats.entropy
90
+ """
91
+
92
+ num_factors = len(names)
93
+ hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
94
+
95
+ ev_index = np.empty(num_factors)
96
+ for col, cnts in enumerate(hist_counts.values()):
97
+ # entropy in nats, normalizes counts
98
+ ev_index[col] = sp_entropy(cnts)
99
+ if normalized:
100
+ if len(cnts) == 1:
101
+ # log(0)
102
+ ev_index[col] = 0
103
+ else:
104
+ ev_index[col] /= np.log(len(cnts))
105
+ return ev_index
106
+
107
+
108
+ def get_num_bins(
109
+ data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
110
+ ) -> NDArray[np.float64]:
111
+ """
112
+ Number of bins or unique values for each metadata factor, used to
113
+ normalize entropy/:term:`diversity<Diversity>`.
114
+
115
+ Parameters
116
+ ----------
117
+ subset_mask: NDArray[np.bool_] | None
118
+ Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
119
+
120
+ Returns
121
+ -------
122
+ NDArray[np.float64]
123
+ """
124
+ # likely cached
125
+ hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
126
+ num_bins = np.empty(len(hist_counts))
127
+ for idx, cnts in enumerate(hist_counts.values()):
128
+ num_bins[idx] = len(cnts)
129
+
130
+ return num_bins
131
+
132
+
133
+ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
134
+ """
135
+ Compute fraction of feature values that are unique --- intended to be used
136
+ for inferring whether variables are categorical.
137
+ """
138
+ if arr.ndim == 1:
139
+ arr = np.expand_dims(arr, axis=1)
140
+ num_samples = arr.shape[0]
141
+ pct_unique = np.empty(arr.shape[1])
142
+ for col in range(arr.shape[1]): # type: ignore
143
+ uvals = np.unique(arr[:, col], axis=0)
144
+ pct_unique[col] = len(uvals) / num_samples
145
+ return pct_unique < threshold
146
+
147
+
148
+ def preprocess_metadata(
149
+ class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
150
+ ) -> tuple[NDArray[Any], list[str], list[bool]]:
151
+ # convert class_labels and dict of lists to matrix of metadata values
152
+ preprocessed_metadata = {"class_label": np.asarray(class_labels, dtype=int)}
153
+
154
+ # map columns of dict that are not numeric (e.g. string) to numeric values
155
+ # that mutual information and diversity functions can accommodate. Each
156
+ # unique string receives a unique integer value.
157
+ for k, v in metadata.items():
158
+ # if not numeric
159
+ v = to_numpy(v)
160
+ if not np.issubdtype(v.dtype, np.number):
161
+ _, mapped_vals = np.unique(v, return_inverse=True)
162
+ preprocessed_metadata[k] = mapped_vals
163
+ else:
164
+ preprocessed_metadata[k] = v
165
+
166
+ data = np.stack(list(preprocessed_metadata.values()), axis=-1)
167
+ names = list(preprocessed_metadata.keys())
168
+ is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
169
+
170
+ return data, names, is_categorical
171
+
172
+
173
+ def heatmap(
174
+ data: NDArray[Any],
175
+ row_labels: NDArray[Any],
176
+ col_labels: NDArray[Any],
177
+ xlabel: str = "",
178
+ ylabel: str = "",
179
+ cbarlabel: str = "",
180
+ ) -> None:
181
+ """
182
+ Plots a formatted heatmap
183
+
184
+ Parameters
185
+ ----------
186
+ data: NDArray
187
+ Array containing numerical values for factors to plot
188
+ row_labels: NDArray
189
+ Array containing the labels for rows in the histogram
190
+ col_labels: NDArray
191
+ Array containing the labels for columns in the histogram
192
+ xlabel: str, default ""
193
+ X-axis label
194
+ ylabel: str, default ""
195
+ Y-axis label
196
+ cbarlabel: str, default ""
197
+ Label for the colorbar
198
+
199
+ """
200
+ import matplotlib
201
+ import matplotlib.pyplot as plt
202
+
203
+ fig, ax = plt.subplots(figsize=(10, 10))
204
+
205
+ # Plot the heatmap
206
+ im = ax.imshow(data, vmin=0, vmax=1.0)
207
+
208
+ # Create colorbar
209
+ cbar = fig.colorbar(im, shrink=0.5)
210
+ cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
211
+ cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
212
+ cbar.set_label(cbarlabel, loc="center")
213
+
214
+ # Show all ticks and label them with the respective list entries.
215
+ ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
216
+ ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
217
+
218
+ ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
219
+ # Rotate the tick labels and set their alignment.
220
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
221
+
222
+ # Turn spines off and create white grid.
223
+ ax.spines[:].set_visible(False)
224
+
225
+ ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
226
+ ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
227
+ ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
228
+ ax.tick_params(which="minor", bottom=False, left=False)
229
+
230
+ if xlabel:
231
+ ax.set_xlabel(xlabel)
232
+ if ylabel:
233
+ ax.set_ylabel(ylabel)
234
+
235
+ valfmt = matplotlib.ticker.FuncFormatter(format_text) # type: ignore
236
+
237
+ # Normalize the threshold to the images color range.
238
+ threshold = im.norm(1.0) / 2.0
239
+
240
+ # Set default alignment to center, but allow it to be
241
+ # overwritten by textkw.
242
+ kw = {"horizontalalignment": "center", "verticalalignment": "center"}
243
+
244
+ # Loop over the data and create a `Text` for each "pixel".
245
+ # Change the text's color depending on the data.
246
+ textcolors = ("white", "black")
247
+ texts = []
248
+ for i in range(data.shape[0]):
249
+ for j in range(data.shape[1]):
250
+ kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
251
+ text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) # type: ignore
252
+ texts.append(text)
253
+
254
+ fig.tight_layout()
255
+ plt.show()
256
+
257
+
258
+ # Function to define how the text is displayed in the heatmap
259
+ def format_text(*args: str) -> str:
260
+ """
261
+ Helper function to format text for heatmap()
262
+
263
+ Parameters
264
+ ----------
265
+ *args: Tuple (str, str)
266
+ Text to be formatted. Second element is ignored, but is a
267
+ mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
268
+
269
+ Returns
270
+ -------
271
+ str
272
+ Formatted text
273
+ """
274
+ x = args[0]
275
+ return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["ParityOutput", "parity", "label_parity"]
4
+
3
5
  import warnings
4
6
  from dataclasses import dataclass
5
- from typing import Generic, Mapping, TypeVar
7
+ from typing import Any, Generic, Mapping, TypeVar
6
8
 
7
9
  import numpy as np
8
10
  from numpy.typing import ArrayLike, NDArray
9
11
  from scipy.stats import chi2_contingency, chisquare
10
12
 
11
- from dataeval._internal.interop import to_numpy
12
- from dataeval._internal.output import OutputMetadata, set_metadata
13
+ from dataeval.interop import to_numpy
14
+ from dataeval.output import OutputMetadata, set_metadata
13
15
 
14
16
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
15
17
 
@@ -17,7 +19,7 @@ TData = TypeVar("TData", np.float64, NDArray[np.float64])
17
19
  @dataclass(frozen=True)
18
20
  class ParityOutput(Generic[TData], OutputMetadata):
19
21
  """
20
- Output class for :func:`parity` and :func:`label_parity` bias metrics
22
+ Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics
21
23
 
22
24
  Attributes
23
25
  ----------
@@ -31,7 +33,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
31
33
  p_value: TData
32
34
 
33
35
 
34
- def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
36
+ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
35
37
  """
36
38
  Digitizes a list of values into a given number of bins.
37
39
 
@@ -64,8 +66,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
64
66
 
65
67
 
66
68
  def format_discretize_factors(
67
- data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
68
- ) -> dict[str, NDArray]:
69
+ data_factors: Mapping[str, NDArray[Any]], continuous_factor_bincounts: Mapping[str, int]
70
+ ) -> dict[str, NDArray[Any]]:
69
71
  """
70
72
  Sets up the internal list of metadata factors.
71
73
 
@@ -115,7 +117,7 @@ def format_discretize_factors(
115
117
  return metadata_factors
116
118
 
117
119
 
118
- def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
120
+ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
119
121
  """
120
122
  Normalize the expected label distribution to match the total number of labels in the observed distribution.
121
123
 
@@ -162,7 +164,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
162
164
  return expected_dist
163
165
 
164
166
 
165
- def validate_dist(label_dist: NDArray, label_name: str):
167
+ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
166
168
  """
167
169
  Verifies that the given label distribution has labels and checks if
168
170
  any labels have frequencies less than 5.
@@ -191,14 +193,15 @@ def validate_dist(label_dist: NDArray, label_name: str):
191
193
  )
192
194
 
193
195
 
194
- @set_metadata("dataeval.metrics")
196
+ @set_metadata()
195
197
  def label_parity(
196
198
  expected_labels: ArrayLike,
197
199
  observed_labels: ArrayLike,
198
200
  num_classes: int | None = None,
199
201
  ) -> ParityOutput[np.float64]:
200
202
  """
201
- Calculate the chi-square statistic to assess the parity between expected and observed label distributions.
203
+ Calculate the chi-square statistic to assess the :term:`parity<Parity>` between expected and
204
+ observed label distributions.
202
205
 
203
206
  This function computes the frequency distribution of classes in both expected and observed labels, normalizes
204
207
  the expected distribution to match the total number of observed labels, and then calculates the chi-square
@@ -217,7 +220,7 @@ def label_parity(
217
220
  Returns
218
221
  -------
219
222
  ParityOutput[np.float64]
220
- chi-squared score and p-value of the test
223
+ chi-squared score and :term`P-Value` of the test
221
224
 
222
225
  Raises
223
226
  ------
@@ -231,8 +234,8 @@ def label_parity(
231
234
  - Providing ``num_classes`` can be helpful if there are classes with zero instances in one of the distributions.
232
235
  - The function first validates the observed distribution and normalizes the expected distribution so that it
233
236
  has the same total number of labels as the observed distribution.
234
- - It then performs a chi-square test to determine if there is a statistically significant difference between
235
- the observed and expected label distributions.
237
+ - It then performs a :term:`Chi-Square Test of Independence` to determine if there is a statistically significant
238
+ difference between the observed and expected label distributions.
236
239
  - This function acts as an interface to the scipy.stats.chisquare method, which is documented at
237
240
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html
238
241
 
@@ -278,14 +281,15 @@ def label_parity(
278
281
  return ParityOutput(cs, p)
279
282
 
280
283
 
281
- @set_metadata("dataeval.metrics")
284
+ @set_metadata()
282
285
  def parity(
283
286
  class_labels: ArrayLike,
284
287
  data_factors: Mapping[str, ArrayLike],
285
288
  continuous_factor_bincounts: Mapping[str, int] | None = None,
286
289
  ) -> ParityOutput[NDArray[np.float64]]:
287
290
  """
288
- Calculate chi-square statistics to assess the relationship between multiple factors and class labels.
291
+ Calculate chi-square statistics to assess the relationship between multiple factors
292
+ and class labels.
289
293
 
290
294
  This function computes the chi-square statistic for each metadata factor to determine if there is
291
295
  a significant relationship between the factor values and class labels. The function handles both categorical
@@ -308,7 +312,7 @@ def parity(
308
312
  -------
309
313
  ParityOutput[NDArray[np.float64]]
310
314
  Arrays of length (num_factors) whose (i)th element corresponds to the
311
- chi-square score and p-value for the relationship between factor i and
315
+ chi-square score and :term:`p-value<P-Value>` for the relationship between factor i and
312
316
  the class labels in the dataset.
313
317
 
314
318
  Raises
@@ -2,8 +2,8 @@
2
2
  Estimators calculate performance bounds and the statistical distance between datasets.
3
3
  """
4
4
 
5
- from dataeval._internal.metrics.ber import BEROutput, ber
6
- from dataeval._internal.metrics.divergence import DivergenceOutput, divergence
7
- from dataeval._internal.metrics.uap import UAPOutput, uap
5
+ from dataeval.metrics.estimators.ber import BEROutput, ber
6
+ from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
7
+ from dataeval.metrics.estimators.uap import UAPOutput, uap
8
8
 
9
9
  __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]