dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
- dataeval/detectors/ood/aegmm.py +66 -0
- dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
- dataeval/detectors/ood/vaegmm.py +75 -0
- dataeval/interop.py +56 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
- dataeval/metrics/bias/metadata.py +358 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
- dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
- dataeval-0.73.0.dist-info/RECORD +73 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/aegmm.py +0 -78
- dataeval/_internal/detectors/ood/vaegmm.py +0 -89
- dataeval/_internal/interop.py +0 -49
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
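Most of this release is a restructure: code moves out of the private `dataeval._internal` package into the public tree listed above. A minimal before/after sketch of what the renames mean for imports (the `balance` function name is inferred from the renamed module; the authoritative re-exports are in each `__init__.py`):

```python
# 0.72.1 -- the implementation lived under the private _internal package:
# from dataeval._internal.metrics.balance import balance

# 0.73.0 -- the same module now has a public path:
from dataeval.metrics.bias.balance import balance
```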
--- /dev/null
+++ b/dataeval/metrics/bias/metadata.py
@@ -0,0 +1,358 @@
+from __future__ import annotations
+
+__all__ = []
+
+import contextlib
+from typing import Any, Mapping
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+from scipy.stats import entropy as sp_entropy
+
+from dataeval.interop import to_numpy
+
+with contextlib.suppress(ImportError):
+    from matplotlib.figure import Figure
+
+CLASS_LABEL = "class_label"
+
+
+def get_counts(
+    data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
+    """
+    Initialize dictionary of histogram counts --- treat categorical values
+    as histogram bins.
+
+    Parameters
+    ----------
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Returns
+    -------
+    counts: Dict
+        histogram counts per metadata factor in `factors`. Each
+        factor will have a different number of bins. Counts get reused
+        across metrics, so hist_counts are cached but only if computed
+        globally, i.e. without masked samples.
+    """
+
+    hist_counts, hist_bins = {}, {}
+    # np.where needed to satisfy linter
+    mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+
+    for cdx, fn in enumerate(names):
+        # linter doesn't like double indexing
+        col_data = data[mask, cdx].squeeze()
+        if is_categorical[cdx]:
+            # if discrete, use unique values as bins
+            bins, cnts = np.unique(col_data, return_counts=True)
+        else:
+            bins = hist_bins.get(fn, "auto")
+            cnts, bins = np.histogram(col_data, bins=bins, density=True)
+
+        hist_counts[fn] = cnts
+        hist_bins[fn] = bins
+
+    return hist_counts, hist_bins
+
+
+def entropy(
+    data: NDArray[Any],
+    names: list[str],
+    is_categorical: list[bool],
+    normalized: bool = False,
+    subset_mask: NDArray[np.bool_] | None = None,
+) -> NDArray[np.float64]:
+    """
+    Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
+    ClasswiseBalance, and Classwise Diversity.
+
+    Compute entropy for discrete/categorical variables and for continuous variables through standard
+    histogram binning.
+
+    Parameters
+    ----------
+    normalized: bool
+        Flag that determines whether or not to normalize entropy by log(num_bins)
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Note
+    ----
+    For continuous variables, histogram bins are chosen automatically. See
+    numpy.histogram for details.
+
+    Returns
+    -------
+    ent: NDArray[np.float64]
+        Entropy estimate per column of X
+
+    See Also
+    --------
+    numpy.histogram
+    scipy.stats.entropy
+    """
+
+    num_factors = len(names)
+    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+
+    ev_index = np.empty(num_factors)
+    for col, cnts in enumerate(hist_counts.values()):
+        # entropy in nats, normalizes counts
+        ev_index[col] = sp_entropy(cnts)
+        if normalized:
+            if len(cnts) == 1:
+                # log(0)
+                ev_index[col] = 0
+            else:
+                ev_index[col] /= np.log(len(cnts))
+    return ev_index
+
+
+def get_num_bins(
+    data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> NDArray[np.float64]:
+    """
+    Number of bins or unique values for each metadata factor, used to
+    normalize entropy/:term:`diversity<Diversity>`.
+
+    Parameters
+    ----------
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Returns
+    -------
+    NDArray[np.float64]
+    """
+    # likely cached
+    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+    num_bins = np.empty(len(hist_counts))
+    for idx, cnts in enumerate(hist_counts.values()):
+        num_bins[idx] = len(cnts)
+
+    return num_bins
+
+
+def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
+    """
+    Compute fraction of feature values that are unique --- intended to be used
+    for inferring whether variables are categorical.
+    """
+    if arr.ndim == 1:
+        arr = np.expand_dims(arr, axis=1)
+    num_samples = arr.shape[0]
+    pct_unique = np.empty(arr.shape[1])
+    for col in range(arr.shape[1]):  # type: ignore
+        uvals = np.unique(arr[:, col], axis=0)
+        pct_unique[col] = len(uvals) / num_samples
+    return pct_unique < threshold
+
+
+def preprocess_metadata(
+    class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
+) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
+    # if class_labels is not numeric
+    class_array = to_numpy(class_labels)
+    if not np.issubdtype(class_array.dtype, np.number):
+        unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
+    else:
+        numerical_labels = np.asarray(class_array, dtype=int)
+        unique_classes = np.unique(class_array)
+
+    # convert class_labels and dict of lists to matrix of metadata values
+    preprocessed_metadata = {CLASS_LABEL: numerical_labels}
+
+    # map columns of dict that are not numeric (e.g. string) to numeric values
+    # that mutual information and diversity functions can accommodate. Each
+    # unique string receives a unique integer value.
+    for k, v in metadata.items():
+        if k == CLASS_LABEL:
+            k = "label_class"
+        # if not numeric
+        v = to_numpy(v)
+        if not np.issubdtype(v.dtype, np.number):
+            _, mapped_vals = np.unique(v, return_inverse=True)
+            preprocessed_metadata[k] = mapped_vals
+        else:
+            preprocessed_metadata[k] = v
+
+    data = np.stack(list(preprocessed_metadata.values()), axis=-1)
+    names = list(preprocessed_metadata.keys())
+    is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
+
+    return data, names, is_categorical, unique_classes
+
+
+def heatmap(
+    data: NDArray[Any],
+    row_labels: list[str] | NDArray[Any],
+    col_labels: list[str] | NDArray[Any],
+    xlabel: str = "",
+    ylabel: str = "",
+    cbarlabel: str = "",
+) -> Figure:
+    """
+    Plots a formatted heatmap
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing numerical values for factors to plot
+    row_labels : ArrayLike
+        List/Array containing the labels for rows in the histogram
+    col_labels : ArrayLike
+        List/Array containing the labels for columns in the histogram
+    xlabel : str, default ""
+        X-axis label
+    ylabel : str, default ""
+        Y-axis label
+    cbarlabel : str, default ""
+        Label for the colorbar
+    """
+    import matplotlib
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    # Plot the heatmap
+    im = ax.imshow(data, vmin=0, vmax=1.0)
+
+    # Create colorbar
+    cbar = fig.colorbar(im, shrink=0.5)
+    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+    cbar.set_label(cbarlabel, loc="center")
+
+    # Show all ticks and label them with the respective list entries.
+    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
+    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+
+    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+    # Rotate the tick labels and set their alignment.
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    ax.spines[:].set_visible(False)
+
+    ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
+    ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    valfmt = matplotlib.ticker.FuncFormatter(format_text)  # type: ignore
+
+    # Normalize the threshold to the images color range.
+    threshold = im.norm(1.0) / 2.0
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    textcolors = ("white", "black")
+    texts = []
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)  # type: ignore
+            texts.append(text)
+
+    fig.tight_layout()
+    return fig
+
+
+# Function to define how the text is displayed in the heatmap
+def format_text(*args: str) -> str:
+    """
+    Helper function to format text for heatmap()
+
+    Parameters
+    ----------
+    *args: Tuple (str, str)
+        Text to be formatted. Second element is ignored, but is a
+        mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+    Returns
+    -------
+    str
+        Formatted text
+    """
+    x = args[0]
+    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+    """
+    Plots a formatted bar plot
+
+    Parameters
+    ----------
+    labels : NDArray
+        Array containing the labels for each bar
+    bar_heights : NDArray
+        Array containing the values for each bar
+    """
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    ax.bar(labels, bar_heights)
+    ax.set_xlabel("Factors")
+
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    fig.tight_layout()
+    return fig
+
+
+def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+    """
+    Creates a single plot of all of the provided images
+
+    Parameters
+    ----------
+    images : NDArray
+        Array containing only the desired images to plot
+    """
+    import matplotlib.pyplot as plt
+
+    num_images = min(num_images, len(images))
+
+    if images.ndim == 4:
+        images = np.moveaxis(images, 1, -1)
+    elif images.ndim == 3:
+        images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+    else:
+        raise ValueError(
+            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+        )
+
+    rows = np.ceil(num_images / 3).astype(int)
+    fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+    if rows == 1:
+        for j in range(3):
+            if j >= len(images):
+                continue
+            axs[j].imshow(images[j])
+            axs[j].axis("off")
+    else:
+        for i in range(rows):
+            for j in range(3):
+                i_j = i * 3 + j
+                if i_j >= len(images):
+                    continue
+                axs[i, j].imshow(images[i_j])
+                axs[i, j].axis("off")
+
+    fig.tight_layout()
+    return fig
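The new `dataeval/metrics/bias/metadata.py` collects the preprocessing, binning, and plotting helpers shared by the bias metrics. A short sketch of how the helpers compose, based only on the signatures above (the sample data is invented):

```python
import numpy as np

from dataeval.metrics.bias.metadata import entropy, get_num_bins, preprocess_metadata

rng = np.random.default_rng(0)
class_labels = rng.choice(["cat", "dog"], 100)         # non-numeric labels get factorized
metadata = {"age": rng.choice([25, 30, 35, 45], 100)}  # invented sample factor

# Stack labels and factors into one matrix; infer which columns are categorical
data, names, is_categorical, unique_classes = preprocess_metadata(class_labels, metadata)

# Per-factor entropy, normalized by log(num_bins), and the bin counts used
ent = entropy(data, names, is_categorical, normalized=True)
bins = get_num_bins(data, names, is_categorical)
```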
--- a/dataeval/_internal/metrics/parity.py
+++ b/dataeval/metrics/bias/parity.py
@@ -1,15 +1,18 @@
 from __future__ import annotations
 
+__all__ = ["ParityOutput", "parity", "label_parity"]
+
 import warnings
 from dataclasses import dataclass
-from typing import Generic, Mapping, TypeVar
+from typing import Any, Generic, Mapping, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 from scipy.stats import chi2_contingency, chisquare
 
-from dataeval.
-from dataeval.
+from dataeval.interop import to_numpy
+from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
+from dataeval.output import OutputMetadata, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -25,13 +28,16 @@ class ParityOutput(Generic[TData], OutputMetadata):
         chi-squared score(s) of the test
     p_value : np.float64 | NDArray[np.float64]
         p-value(s) of the test
+    metadata_names: list[str] | None
+        Names of each metadata factor
     """
 
     score: TData
     p_value: TData
+    metadata_names: list[str] | None
 
 
-def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
+def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
     """
     Digitizes a list of values into a given number of bins.
 
@@ -64,8 +70,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
 
 
 def format_discretize_factors(
-
-) -> dict[str, NDArray]:
+    data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
+) -> dict[str, NDArray[Any]]:
     """
     Sets up the internal list of metadata factors.
 
@@ -87,35 +93,37 @@ def format_discretize_factors(
     Each key is a metadata factor, whose value is the discrete per-image factor values.
     """
 
-    invalid_keys = set(continuous_factor_bincounts.keys()) - set(
+    invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
     if invalid_keys:
         raise KeyError(
             f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
             "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
         )
 
+    warn = []
     metadata_factors = {}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    for i, name in enumerate(names):
+        if name == CLASS_LABEL:
+            continue
+        if name in continuous_factor_bincounts:
+            metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
+        elif not is_categorical[i]:
+            warn.append(name)
+            metadata_factors[name] = data[:, i]
+        else:
+            metadata_factors[name] = data[:, i]
+
+    if warn:
+        warnings.warn(
+            f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
+    {warn}",
+            UserWarning,
+        )
 
     return metadata_factors
 
 
-def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
+def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
     """
     Normalize the expected label distribution to match the total number of labels in the observed distribution.
 
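The rewritten `format_discretize_factors` takes the preprocessed matrix, factor names, and categorical flags directly; it digitizes every factor named in `continuous_factor_bincounts` and batches the names of apparently continuous but unbinned factors into one `UserWarning`. A sketch of that flow (these are internal helpers, and the sample factor is invented):

```python
import numpy as np

from dataeval.metrics.bias.metadata import preprocess_metadata
from dataeval.metrics.bias.parity import format_discretize_factors

rng = np.random.default_rng(0)
labels = rng.choice([0, 1, 2], 100)
metadata = {"age": rng.choice([25, 30, 35, 45], 100)}  # invented sample factor

data, names, is_categorical, _ = preprocess_metadata(labels, metadata)

# "age" is digitized into 4 bins; a factor that looks continuous but has no
# bincount entry would pass through unchanged and trigger the UserWarning.
factors = format_discretize_factors(data, names, is_categorical, {"age": 4})
```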
@@ -162,7 +170,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
     return expected_dist
 
 
-def validate_dist(label_dist: NDArray, label_name: str):
+def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
     """
     Verifies that the given label distribution has labels and checks if
     any labels have frequencies less than 5.
@@ -191,7 +199,7 @@ def validate_dist(label_dist: NDArray, label_name: str):
         )
 
 
-@set_metadata(
+@set_metadata()
 def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
@@ -245,7 +253,7 @@ def label_parity(
     >>> expected_labels = np_random_gen.choice([0, 1, 2, 3, 4], (100))
     >>> observed_labels = np_random_gen.choice([2, 3, 0, 4, 1], (100))
     >>> label_parity(expected_labels, observed_labels)
-    ParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
+    ParityOutput(score=14.007374204742625, p_value=0.0072715574616218, metadata_names=None)
     """
 
     # Calculate
@@ -276,13 +284,13 @@ def label_parity(
     )
 
     cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
-    return ParityOutput(cs, p)
+    return ParityOutput(cs, p, None)
 
 
-@set_metadata(
+@set_metadata()
 def parity(
     class_labels: ArrayLike,
-
+    metadata: Mapping[str, ArrayLike],
     continuous_factor_bincounts: Mapping[str, int] | None = None,
 ) -> ParityOutput[NDArray[np.float64]]:
     """
@@ -297,12 +305,12 @@ def parity(
     ----------
     class_labels: ArrayLike
         List of class labels for each image
-
+    metadata: Mapping[str, ArrayLike]
         The dataset factors, which are per-image metadata attributes.
         Each key of dataset_factors is a factor, whose value is the per-image factor values.
     continuous_factor_bincounts : Mapping[str, int] | None, default None
         A dictionary specifying the number of bins for discretizing the continuous factors.
-        The keys should correspond to the names of continuous factors in `
+        The keys should correspond to the names of continuous factors in `metadata`,
         and the values should be the number of bins to use for discretization.
         If not provided, no discretization is applied.
 
@@ -335,42 +343,44 @@ def parity(
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
     >>> labels = np_random_gen.choice([0, 1, 2], (100))
-    >>> 
+    >>> metadata = {
     ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
     ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
     ...     "gender": np_random_gen.choice(["M", "F"], (100)),
     ... }
    >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-    >>> parity(labels,
-    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]))
-    """
+    >>> parity(labels, metadata, continuous_factor_bincounts)
+    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
+    """  # noqa: E501
     if len(np.shape(class_labels)) > 1:
         raise ValueError(
             f"Got class labels with {len(np.shape(class_labels))}-dimensional",
             f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
         )
 
-
+    data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
     continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
 
-
-
+    factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
+
+    # unique class labels
+    class_idx = names.index(CLASS_LABEL)
+    u_cls = np.unique(data[:, class_idx])
 
     chi_scores = np.zeros(len(factors))
     p_values = np.zeros(len(factors))
-    n_cls = len(np.unique(labels))
     not_enough_data = {}
     for i, (current_factor_name, factor_values) in enumerate(factors.items()):
         unique_factor_values = np.unique(factor_values)
-        contingency_matrix = np.zeros((len(unique_factor_values),
+        contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
 
         # TODO: Vectorize this nested for loop
         for fi, factor_value in enumerate(unique_factor_values):
-            for label in
-                with_both = np.bitwise_and((
+            for label in u_cls:
+                with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
                 contingency_matrix[fi, label] = np.sum(with_both)
                 if 0 < contingency_matrix[fi, label] < 5:
                     if current_factor_name not in not_enough_data:
@@ -412,4 +422,4 @@ def parity(
             UserWarning,
         )
 
-    return ParityOutput(chi_scores, p_values)
+    return ParityOutput(chi_scores, p_values, list(metadata.keys()))
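Net effect on the public API: `parity()` now takes the metadata mapping directly instead of pre-built data factors, and `ParityOutput` gains a `metadata_names` field. Mirroring the doctest above (the RNG seed here is invented, so the exact scores will differ):

```python
import numpy as np

from dataeval.metrics.bias.parity import parity

rng = np.random.default_rng(175)  # hypothetical seed
labels = rng.choice([0, 1, 2], 100)
metadata = {
    "age": rng.choice([25, 30, 35, 45], 100),
    "income": rng.choice([50000, 65000, 80000], 100),
    "gender": rng.choice(["M", "F"], 100),
}

result = parity(labels, metadata, continuous_factor_bincounts={"age": 4, "income": 3})
print(result.metadata_names)  # ['age', 'income', 'gender'] -- new in 0.73.0
```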
--- a/dataeval/metrics/estimators/__init__.py
+++ b/dataeval/metrics/estimators/__init__.py
@@ -2,8 +2,8 @@
 Estimators calculate performance bounds and the statistical distance between datasets.
 """
 
-from dataeval.
-from dataeval.
-from dataeval.
+from dataeval.metrics.estimators.ber import BEROutput, ber
+from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
+from dataeval.metrics.estimators.uap import UAPOutput, uap
 
 __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]
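The estimator exports themselves are unchanged; only the backing module paths moved out of `_internal`, so existing imports keep resolving:

```python
from dataeval.metrics.estimators import BEROutput, ber, divergence, uap
```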