dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (95)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
  18. dataeval/detectors/ood/aegmm.py +66 -0
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
  25. dataeval/detectors/ood/vaegmm.py +75 -0
  26. dataeval/interop.py +56 -0
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
  32. dataeval/metrics/bias/metadata.py +358 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/lazy.py +26 -0
  51. dataeval/utils/metadata.py +258 -0
  52. dataeval/utils/shared.py +151 -0
  53. dataeval/{_internal → utils}/split_dataset.py +98 -33
  54. dataeval/utils/tensorflow/__init__.py +7 -6
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
  56. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
  57. dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
  58. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
  59. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
  60. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  61. dataeval/utils/torch/__init__.py +7 -3
  62. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  63. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  64. dataeval/utils/torch/models.py +138 -0
  65. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  66. dataeval/{_internal → utils/torch}/utils.py +3 -1
  67. dataeval/workflows/__init__.py +1 -1
  68. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  69. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
  70. dataeval-0.73.0.dist-info/RECORD +73 -0
  71. dataeval/_internal/detectors/__init__.py +0 -0
  72. dataeval/_internal/detectors/drift/__init__.py +0 -0
  73. dataeval/_internal/detectors/ood/__init__.py +0 -0
  74. dataeval/_internal/detectors/ood/aegmm.py +0 -78
  75. dataeval/_internal/detectors/ood/vaegmm.py +0 -89
  76. dataeval/_internal/interop.py +0 -49
  77. dataeval/_internal/metrics/__init__.py +0 -0
  78. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  79. dataeval/_internal/metrics/utils.py +0 -447
  80. dataeval/_internal/models/__init__.py +0 -0
  81. dataeval/_internal/models/pytorch/__init__.py +0 -0
  82. dataeval/_internal/models/pytorch/utils.py +0 -67
  83. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  84. dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
  85. dataeval/_internal/workflows/__init__.py +0 -0
  86. dataeval/detectors/drift/kernels/__init__.py +0 -10
  87. dataeval/detectors/drift/updates/__init__.py +0 -8
  88. dataeval/utils/tensorflow/models/__init__.py +0 -9
  89. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  90. dataeval/utils/torch/datasets/__init__.py +0 -12
  91. dataeval/utils/torch/models/__init__.py +0 -11
  92. dataeval/utils/torch/trainer/__init__.py +0 -7
  93. dataeval-0.72.1.dist-info/RECORD +0 -81
  94. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
  95. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
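
Most of the churn in the file list above is structural: the private dataeval._internal tree is removed and its modules are promoted to public packages (detectors, metrics, utils, workflows), alongside new top-level modules such as dataeval/interop.py, dataeval/output.py, and dataeval/utils/metadata.py. For code that imported from the private tree, the moves map directly onto new import paths, as the parity.py hunk further below shows. A representative before/after sketch, using paths taken from that hunk:

# 0.72.1 private paths (removed in 0.73.0)
from dataeval._internal.interop import to_numpy
from dataeval._internal.output import OutputMetadata, set_metadata

# 0.73.0 public locations
from dataeval.interop import to_numpy
from dataeval.output import OutputMetadata, set_metadata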
dataeval/metrics/bias/metadata.py (new file)
@@ -0,0 +1,358 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import contextlib
+ from typing import Any, Mapping
+
+ import numpy as np
+ from numpy.typing import ArrayLike, NDArray
+ from scipy.stats import entropy as sp_entropy
+
+ from dataeval.interop import to_numpy
+
+ with contextlib.suppress(ImportError):
+     from matplotlib.figure import Figure
+
+ CLASS_LABEL = "class_label"
+
+
+ def get_counts(
+     data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+ ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
+     """
+     Initialize dictionary of histogram counts --- treat categorical values
+     as histogram bins.
+
+     Parameters
+     ----------
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     counts: Dict
+         histogram counts per metadata factor in `factors`. Each
+         factor will have a different number of bins. Counts get reused
+         across metrics, so hist_counts are cached but only if computed
+         globally, i.e. without masked samples.
+     """
+
+     hist_counts, hist_bins = {}, {}
+     # np.where needed to satisfy linter
+     mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+
+     for cdx, fn in enumerate(names):
+         # linter doesn't like double indexing
+         col_data = data[mask, cdx].squeeze()
+         if is_categorical[cdx]:
+             # if discrete, use unique values as bins
+             bins, cnts = np.unique(col_data, return_counts=True)
+         else:
+             bins = hist_bins.get(fn, "auto")
+             cnts, bins = np.histogram(col_data, bins=bins, density=True)
+
+         hist_counts[fn] = cnts
+         hist_bins[fn] = bins
+
+     return hist_counts, hist_bins
+
+
+ def entropy(
+     data: NDArray[Any],
+     names: list[str],
+     is_categorical: list[bool],
+     normalized: bool = False,
+     subset_mask: NDArray[np.bool_] | None = None,
+ ) -> NDArray[np.float64]:
+     """
+     Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
+     ClasswiseBalance, and Classwise Diversity.
+
+     Compute entropy for discrete/categorical variables and for continuous variables through standard
+     histogram binning.
+
+     Parameters
+     ----------
+     normalized: bool
+         Flag that determines whether or not to normalize entropy by log(num_bins)
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Note
+     ----
+     For continuous variables, histogram bins are chosen automatically. See
+     numpy.histogram for details.
+
+     Returns
+     -------
+     ent: NDArray[np.float64]
+         Entropy estimate per column of X
+
+     See Also
+     --------
+     numpy.histogram
+     scipy.stats.entropy
+     """
+
+     num_factors = len(names)
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+
+     ev_index = np.empty(num_factors)
+     for col, cnts in enumerate(hist_counts.values()):
+         # entropy in nats, normalizes counts
+         ev_index[col] = sp_entropy(cnts)
+         if normalized:
+             if len(cnts) == 1:
+                 # log(0)
+                 ev_index[col] = 0
+             else:
+                 ev_index[col] /= np.log(len(cnts))
+     return ev_index
+
+
+ def get_num_bins(
+     data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+ ) -> NDArray[np.float64]:
+     """
+     Number of bins or unique values for each metadata factor, used to
+     normalize entropy/:term:`diversity<Diversity>`.
+
+     Parameters
+     ----------
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     NDArray[np.float64]
+     """
+     # likely cached
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     num_bins = np.empty(len(hist_counts))
+     for idx, cnts in enumerate(hist_counts.values()):
+         num_bins[idx] = len(cnts)
+
+     return num_bins
+
+
+ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
+     """
+     Compute fraction of feature values that are unique --- intended to be used
+     for inferring whether variables are categorical.
+     """
+     if arr.ndim == 1:
+         arr = np.expand_dims(arr, axis=1)
+     num_samples = arr.shape[0]
+     pct_unique = np.empty(arr.shape[1])
+     for col in range(arr.shape[1]):  # type: ignore
+         uvals = np.unique(arr[:, col], axis=0)
+         pct_unique[col] = len(uvals) / num_samples
+     return pct_unique < threshold
+
+
+ def preprocess_metadata(
+     class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
+ ) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
+     # if class_labels is not numeric
+     class_array = to_numpy(class_labels)
+     if not np.issubdtype(class_array.dtype, np.number):
+         unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
+     else:
+         numerical_labels = np.asarray(class_array, dtype=int)
+         unique_classes = np.unique(class_array)
+
+     # convert class_labels and dict of lists to matrix of metadata values
+     preprocessed_metadata = {CLASS_LABEL: numerical_labels}
+
+     # map columns of dict that are not numeric (e.g. string) to numeric values
+     # that mutual information and diversity functions can accommodate. Each
+     # unique string receives a unique integer value.
+     for k, v in metadata.items():
+         if k == CLASS_LABEL:
+             k = "label_class"
+         # if not numeric
+         v = to_numpy(v)
+         if not np.issubdtype(v.dtype, np.number):
+             _, mapped_vals = np.unique(v, return_inverse=True)
+             preprocessed_metadata[k] = mapped_vals
+         else:
+             preprocessed_metadata[k] = v
+
+     data = np.stack(list(preprocessed_metadata.values()), axis=-1)
+     names = list(preprocessed_metadata.keys())
+     is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
+
+     return data, names, is_categorical, unique_classes
+
+
+ def heatmap(
+     data: NDArray[Any],
+     row_labels: list[str] | NDArray[Any],
+     col_labels: list[str] | NDArray[Any],
+     xlabel: str = "",
+     ylabel: str = "",
+     cbarlabel: str = "",
+ ) -> Figure:
+     """
+     Plots a formatted heatmap
+
+     Parameters
+     ----------
+     data : NDArray
+         Array containing numerical values for factors to plot
+     row_labels : ArrayLike
+         List/Array containing the labels for rows in the histogram
+     col_labels : ArrayLike
+         List/Array containing the labels for columns in the histogram
+     xlabel : str, default ""
+         X-axis label
+     ylabel : str, default ""
+         Y-axis label
+     cbarlabel : str, default ""
+         Label for the colorbar
+     """
+     import matplotlib
+     import matplotlib.pyplot as plt
+
+     fig, ax = plt.subplots(figsize=(10, 10))
+
+     # Plot the heatmap
+     im = ax.imshow(data, vmin=0, vmax=1.0)
+
+     # Create colorbar
+     cbar = fig.colorbar(im, shrink=0.5)
+     cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+     cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+     cbar.set_label(cbarlabel, loc="center")
+
+     # Show all ticks and label them with the respective list entries.
+     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
+     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+
+     ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+     # Rotate the tick labels and set their alignment.
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     # Turn spines off and create white grid.
+     ax.spines[:].set_visible(False)
+
+     ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
+     ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
+     ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+     ax.tick_params(which="minor", bottom=False, left=False)
+
+     if xlabel:
+         ax.set_xlabel(xlabel)
+     if ylabel:
+         ax.set_ylabel(ylabel)
+
+     valfmt = matplotlib.ticker.FuncFormatter(format_text)  # type: ignore
+
+     # Normalize the threshold to the images color range.
+     threshold = im.norm(1.0) / 2.0
+
+     # Set default alignment to center, but allow it to be
+     # overwritten by textkw.
+     kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+     # Loop over the data and create a `Text` for each "pixel".
+     # Change the text's color depending on the data.
+     textcolors = ("white", "black")
+     texts = []
+     for i in range(data.shape[0]):
+         for j in range(data.shape[1]):
+             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)  # type: ignore
+             texts.append(text)
+
+     fig.tight_layout()
+     return fig
+
+
+ # Function to define how the text is displayed in the heatmap
+ def format_text(*args: str) -> str:
+     """
+     Helper function to format text for heatmap()
+
+     Parameters
+     ----------
+     *args: Tuple (str, str)
+         Text to be formatted. Second element is ignored, but is a
+         mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+     Returns
+     -------
+     str
+         Formatted text
+     """
+     x = args[0]
+     return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+ def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+     """
+     Plots a formatted bar plot
+
+     Parameters
+     ----------
+     labels : NDArray
+         Array containing the labels for each bar
+     bar_heights : NDArray
+         Array containing the values for each bar
+     """
+     import matplotlib.pyplot as plt
+
+     fig, ax = plt.subplots(figsize=(10, 10))
+
+     ax.bar(labels, bar_heights)
+     ax.set_xlabel("Factors")
+
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     fig.tight_layout()
+     return fig
+
+
+ def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+     """
+     Creates a single plot of all of the provided images
+
+     Parameters
+     ----------
+     images : NDArray
+         Array containing only the desired images to plot
+     """
+     import matplotlib.pyplot as plt
+
+     num_images = min(num_images, len(images))
+
+     if images.ndim == 4:
+         images = np.moveaxis(images, 1, -1)
+     elif images.ndim == 3:
+         images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+     else:
+         raise ValueError(
+             f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+         )
+
+     rows = np.ceil(num_images / 3).astype(int)
+     fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+     if rows == 1:
+         for j in range(3):
+             if j >= len(images):
+                 continue
+             axs[j].imshow(images[j])
+             axs[j].axis("off")
+     else:
+         for i in range(rows):
+             for j in range(3):
+                 i_j = i * 3 + j
+                 if i_j >= len(images):
+                     continue
+                 axs[i, j].imshow(images[i_j])
+                 axs[i, j].axis("off")
+
+     fig.tight_layout()
+     return fig
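
The new module above collects the metadata preprocessing and plotting helpers shared by the bias metrics; its __all__ is empty, so these are internal helpers rather than public API. A minimal sketch of how the preprocessing pieces compose, using hypothetical labels and factors (only the function names and signatures are taken from the hunk above):

import numpy as np
from dataeval.metrics.bias.metadata import entropy, get_num_bins, preprocess_metadata

rng = np.random.default_rng(0)
class_labels = rng.choice([0, 1, 2], 100)         # hypothetical per-image class labels
metadata = {
    "age": rng.choice([25, 30, 35, 45], 100),     # numeric factor
    "gender": rng.choice(["M", "F"], 100),        # string factor, mapped to integer codes
}

# Flatten labels plus factors into a single matrix; strings become integer codes
data, names, is_categorical, unique_classes = preprocess_metadata(class_labels, metadata)

# Per-factor entropy (optionally normalized by log of the bin count) and bin counts
ent = entropy(data, names, is_categorical, normalized=True)
num_bins = get_num_bins(data, names, is_categorical)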
dataeval/{_internal/metrics → metrics/bias}/parity.py
@@ -1,15 +1,18 @@
  from __future__ import annotations
 
+ __all__ = ["ParityOutput", "parity", "label_parity"]
+
  import warnings
  from dataclasses import dataclass
- from typing import Generic, Mapping, TypeVar
+ from typing import Any, Generic, Mapping, TypeVar
 
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from scipy.stats import chi2_contingency, chisquare
 
- from dataeval._internal.interop import to_numpy
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.interop import to_numpy
+ from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
+ from dataeval.output import OutputMetadata, set_metadata
 
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -25,13 +28,16 @@ class ParityOutput(Generic[TData], OutputMetadata):
          chi-squared score(s) of the test
      p_value : np.float64 | NDArray[np.float64]
          p-value(s) of the test
+     metadata_names: list[str] | None
+         Names of each metadata factor
      """
 
      score: TData
      p_value: TData
+     metadata_names: list[str] | None
 
 
- def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
+ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
      """
      Digitizes a list of values into a given number of bins.
 
@@ -64,8 +70,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
 
 
  def format_discretize_factors(
-     data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
- ) -> dict[str, NDArray]:
+     data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
+ ) -> dict[str, NDArray[Any]]:
      """
      Sets up the internal list of metadata factors.
 
@@ -87,35 +93,37 @@ def format_discretize_factors(
          Each key is a metadata factor, whose value is the discrete per-image factor values.
      """
 
-     invalid_keys = set(continuous_factor_bincounts.keys()) - set(data_factors.keys())
+     invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
      if invalid_keys:
          raise KeyError(
              f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
              "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
          )
 
+     warn = []
      metadata_factors = {}
-
-     # make sure each factor has the same number of entries
-     lengths = []
-     for arr in data_factors.values():
-         lengths.append(arr.shape)
-
-     if lengths[1:] != lengths[:-1]:
-         raise ValueError("The lengths of each entry in the dictionary are not equal." f" Found lengths {lengths}")
-
-     metadata_factors = {
-         name: val
-         if name not in continuous_factor_bincounts
-         else digitize_factor_bins(val, continuous_factor_bincounts[name], name)
-         for name, val in data_factors.items()
-         if name != "class"
-     }
+     for i, name in enumerate(names):
+         if name == CLASS_LABEL:
+             continue
+         if name in continuous_factor_bincounts:
+             metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
+         elif not is_categorical[i]:
+             warn.append(name)
+             metadata_factors[name] = data[:, i]
+         else:
+             metadata_factors[name] = data[:, i]
+
+     if warn:
+         warnings.warn(
+             f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
+             {warn}",
+             UserWarning,
+         )
 
      return metadata_factors
 
 
- def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
+ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
      """
      Normalize the expected label distribution to match the total number of labels in the observed distribution.
 
@@ -162,7 +170,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
      return expected_dist
 
 
- def validate_dist(label_dist: NDArray, label_name: str):
+ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
      """
      Verifies that the given label distribution has labels and checks if
      any labels have frequencies less than 5.
@@ -191,7 +199,7 @@ def validate_dist(label_dist: NDArray, label_name: str):
          )
 
 
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def label_parity(
      expected_labels: ArrayLike,
      observed_labels: ArrayLike,
@@ -245,7 +253,7 @@ def label_parity(
      >>> expected_labels = np_random_gen.choice([0, 1, 2, 3, 4], (100))
      >>> observed_labels = np_random_gen.choice([2, 3, 0, 4, 1], (100))
      >>> label_parity(expected_labels, observed_labels)
-     ParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
+     ParityOutput(score=14.007374204742625, p_value=0.0072715574616218, metadata_names=None)
      """
 
      # Calculate
@@ -276,13 +284,13 @@ def label_parity(
      )
 
      cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
-     return ParityOutput(cs, p)
+     return ParityOutput(cs, p, None)
 
 
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def parity(
      class_labels: ArrayLike,
-     data_factors: Mapping[str, ArrayLike],
+     metadata: Mapping[str, ArrayLike],
      continuous_factor_bincounts: Mapping[str, int] | None = None,
  ) -> ParityOutput[NDArray[np.float64]]:
      """
@@ -297,12 +305,12 @@ def parity(
      ----------
      class_labels: ArrayLike
          List of class labels for each image
-     data_factors: Mapping[str, ArrayLike]
+     metadata: Mapping[str, ArrayLike]
          The dataset factors, which are per-image metadata attributes.
          Each key of dataset_factors is a factor, whose value is the per-image factor values.
      continuous_factor_bincounts : Mapping[str, int] | None, default None
          A dictionary specifying the number of bins for discretizing the continuous factors.
-         The keys should correspond to the names of continuous factors in `data_factors`,
+         The keys should correspond to the names of continuous factors in `metadata`,
          and the values should be the number of bins to use for discretization.
          If not provided, no discretization is applied.
 
@@ -335,42 +343,44 @@ def parity(
      Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
      >>> labels = np_random_gen.choice([0, 1, 2], (100))
-     >>> data_factors = {
+     >>> metadata = {
      ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
      ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
      ...     "gender": np_random_gen.choice(["M", "F"], (100)),
      ... }
      >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-     >>> parity(labels, data_factors, continuous_factor_bincounts)
-     ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]))
-     """
+     >>> parity(labels, metadata, continuous_factor_bincounts)
+     ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
+     """  # noqa: E501
      if len(np.shape(class_labels)) > 1:
          raise ValueError(
              f"Got class labels with {len(np.shape(class_labels))}-dimensional",
              f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
          )
 
-     data_factors_np = {k: to_numpy(v) for k, v in data_factors.items()}
+     data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
      continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
 
-     labels = to_numpy(class_labels)
-     factors = format_discretize_factors(data_factors_np, continuous_factor_bincounts)
+     factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
+
+     # unique class labels
+     class_idx = names.index(CLASS_LABEL)
+     u_cls = np.unique(data[:, class_idx])
 
      chi_scores = np.zeros(len(factors))
      p_values = np.zeros(len(factors))
-     n_cls = len(np.unique(labels))
      not_enough_data = {}
      for i, (current_factor_name, factor_values) in enumerate(factors.items()):
          unique_factor_values = np.unique(factor_values)
-         contingency_matrix = np.zeros((len(unique_factor_values), n_cls))
+         contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
          # Builds a contingency matrix where entry at index (r,c) represents
          # the frequency of current_factor_name achieving value unique_factor_values[r]
          # at a data point with class c.
 
          # TODO: Vectorize this nested for loop
          for fi, factor_value in enumerate(unique_factor_values):
-             for label in range(n_cls):
-                 with_both = np.bitwise_and((labels == label), factor_values == factor_value)
+             for label in u_cls:
+                 with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
                  contingency_matrix[fi, label] = np.sum(with_both)
                  if 0 < contingency_matrix[fi, label] < 5:
                      if current_factor_name not in not_enough_data:
@@ -412,4 +422,4 @@ def parity(
              UserWarning,
          )
 
-     return ParityOutput(chi_scores, p_values)
+     return ParityOutput(chi_scores, p_values, list(metadata.keys()))
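
For callers of parity, the visible changes in this hunk are the rename of the data_factors parameter to metadata, the switch to preprocess_metadata internally, and the new metadata_names field on ParityOutput. A short sketch of the updated call, adapted from the docstring example in the hunk (random inputs are illustrative; scores depend on the seed):

import numpy as np
from dataeval.metrics.bias.parity import parity

rng = np.random.default_rng()
labels = rng.choice([0, 1, 2], (100))
metadata = {                                     # previously passed as `data_factors`
    "age": rng.choice([25, 30, 35, 45], (100)),
    "income": rng.choice([50000, 65000, 80000], (100)),
    "gender": rng.choice(["M", "F"], (100)),
}

result = parity(labels, metadata, continuous_factor_bincounts={"age": 4, "income": 3})
print(result.metadata_names)                     # ['age', 'income', 'gender']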
dataeval/metrics/estimators/__init__.py
@@ -2,8 +2,8 @@
  Estimators calculate performance bounds and the statistical distance between datasets.
  """
 
- from dataeval._internal.metrics.ber import BEROutput, ber
- from dataeval._internal.metrics.divergence import DivergenceOutput, divergence
- from dataeval._internal.metrics.uap import UAPOutput, uap
+ from dataeval.metrics.estimators.ber import BEROutput, ber
+ from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
+ from dataeval.metrics.estimators.uap import UAPOutput, uap
 
  __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]