dataeval 0.72.2__py3-none-any.whl → 0.73.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/__init__.py +1 -1
  3. dataeval/detectors/drift/__init__.py +1 -1
  4. dataeval/detectors/drift/base.py +2 -2
  5. dataeval/detectors/linters/clusterer.py +1 -1
  6. dataeval/detectors/ood/__init__.py +1 -1
  7. dataeval/detectors/ood/ae.py +14 -6
  8. dataeval/detectors/ood/aegmm.py +14 -6
  9. dataeval/detectors/ood/base.py +9 -3
  10. dataeval/detectors/ood/llr.py +22 -16
  11. dataeval/detectors/ood/vae.py +14 -6
  12. dataeval/detectors/ood/vaegmm.py +14 -6
  13. dataeval/interop.py +9 -7
  14. dataeval/metrics/bias/balance.py +50 -44
  15. dataeval/metrics/bias/coverage.py +38 -6
  16. dataeval/metrics/bias/diversity.py +117 -65
  17. dataeval/metrics/bias/metadata.py +225 -60
  18. dataeval/metrics/bias/parity.py +68 -54
  19. dataeval/utils/__init__.py +4 -3
  20. dataeval/utils/lazy.py +26 -0
  21. dataeval/utils/metadata.py +258 -0
  22. dataeval/utils/shared.py +1 -1
  23. dataeval/utils/split_dataset.py +12 -6
  24. dataeval/utils/tensorflow/_internal/gmm.py +8 -2
  25. dataeval/utils/tensorflow/_internal/loss.py +20 -11
  26. dataeval/utils/tensorflow/_internal/{pixelcnn.py → models.py} +371 -77
  27. dataeval/utils/tensorflow/_internal/trainer.py +12 -5
  28. dataeval/utils/tensorflow/_internal/utils.py +70 -71
  29. dataeval/utils/torch/datasets.py +2 -2
  30. dataeval/workflows/__init__.py +1 -1
  31. {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/METADATA +3 -3
  32. {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/RECORD +34 -33
  33. dataeval/utils/tensorflow/_internal/autoencoder.py +0 -316
  34. {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/LICENSE.txt +0 -0
  35. {dataeval-0.72.2.dist-info → dataeval-0.73.1.dist-info}/WHEEL +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ import contextlib
5
6
  from typing import Any, Mapping
6
7
 
7
8
  import numpy as np
@@ -10,54 +11,87 @@ from scipy.stats import entropy as sp_entropy
10
11
 
11
12
  from dataeval.interop import to_numpy
12
13
 
14
+ with contextlib.suppress(ImportError):
15
+ from matplotlib.figure import Figure
16
+
17
+ CLASS_LABEL = "class_label"
18
+
13
19
 
14
20
  def get_counts(
15
- data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
16
- ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
21
+ data: NDArray[Any],
22
+ names: list[str],
23
+ continuous_factor_bincounts: Mapping[str, int] | None = None,
24
+ subset_mask: NDArray[np.bool_] | None = None,
25
+ hist_cache: dict[str, NDArray[np.intp]] | None = None,
26
+ ) -> dict[str, NDArray[np.intp]]:
17
27
  """
18
28
  Initialize dictionary of histogram counts --- treat categorical values
19
29
  as histogram bins.
20
30
 
21
31
  Parameters
22
32
  ----------
23
- subset_mask: NDArray[np.bool_] | None
33
+ data : NDArray
34
+ Array containing numerical values for metadata factors
35
+ names : list[str]
36
+ Names of metadata factors -- keys of the metadata dictionary
37
+ continuous_factor_bincounts : Mapping[str, int] or None, default None
38
+ The factors in names that have continuous values, each mapped to the number of bins to
39
+ discretize values into. All factors are treated as having discrete values unless they
40
+ are specified as keys in this dictionary. Each key of this dictionary must occur as an entry
41
+ in names.
42
+ Names of metadata factors -- keys of the metadata dictionary
43
+ subset_mask : NDArray[np.bool_] or None, default None
24
44
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
45
+ hist_cache : dict[str, NDArray[np.intp]] or None, default None
46
+ Optional cache to store histogram counts
25
47
 
26
48
  Returns
27
49
  -------
28
- counts: Dict
50
+ dict[str, NDArray[np.intp]]
29
51
  histogram counts per metadata factor in `factors`. Each
30
52
  factor will have a different number of bins. Counts get reused
31
53
  across metrics, so hist_counts are cached but only if computed
32
54
  globally, i.e. without masked samples.
33
55
  """
34
56
 
35
- hist_counts, hist_bins = {}, {}
36
- # np.where needed to satisfy linter
37
- mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
57
+ hist_counts = {}
58
+
59
+ mask = subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=np.bool_)
38
60
 
39
61
  for cdx, fn in enumerate(names):
40
- # linter doesn't like double indexing
41
- col_data = data[mask, cdx].squeeze()
42
- if is_categorical[cdx]:
43
- # if discrete, use unique values as bins
44
- bins, cnts = np.unique(col_data, return_counts=True)
62
+ if hist_cache is not None and fn in hist_cache:
63
+ cnts = hist_cache[fn]
45
64
  else:
46
- bins = hist_bins.get(fn, "auto")
47
- cnts, bins = np.histogram(col_data, bins=bins, density=True)
65
+ hist_edges = np.array([-np.inf, np.inf])
66
+ cnts = np.array([len(data[:, cdx].squeeze())])
67
+ # linter doesn't like double indexing
68
+ col_data = np.array(data[mask, cdx].squeeze(), dtype=np.float64)
69
+
70
+ if continuous_factor_bincounts and fn in continuous_factor_bincounts:
71
+ num_bins = continuous_factor_bincounts[fn]
72
+ _, hist_edges = np.histogram(data[:, cdx].squeeze(), bins=num_bins, density=True)
73
+ hist_edges[-1] = np.inf
74
+ hist_edges[0] = -np.inf
75
+ disc_col_data = np.digitize(col_data, np.array(hist_edges))
76
+ _, cnts = np.unique(disc_col_data, return_counts=True)
77
+ else:
78
+ _, cnts = np.unique(col_data, return_counts=True)
79
+
80
+ if hist_cache is not None:
81
+ hist_cache[fn] = cnts
48
82
 
49
83
  hist_counts[fn] = cnts
50
- hist_bins[fn] = bins
51
84
 
52
- return hist_counts, hist_bins
85
+ return hist_counts
53
86
 
54
87
 
55
88
  def entropy(
56
89
  data: NDArray[Any],
57
90
  names: list[str],
58
- is_categorical: list[bool],
91
+ continuous_factor_bincounts: Mapping[str, int] | None = None,
59
92
  normalized: bool = False,
60
93
  subset_mask: NDArray[np.bool_] | None = None,
94
+ hist_cache: dict[str, NDArray[np.intp]] | None = None,
61
95
  ) -> NDArray[np.float64]:
62
96
  """
63
97
  Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
@@ -68,19 +102,30 @@ def entropy(
68
102
 
69
103
  Parameters
70
104
  ----------
71
- normalized: bool
105
+ data : NDArray
106
+ Array containing numerical values for metadata factors
107
+ names : list[str]
108
+ Names of metadata factors -- keys of the metadata dictionary
109
+ continuous_factor_bincounts : Mapping[str, int] or None, default None
110
+ The factors in names that have continuous values, each mapped to the number of bins to
111
+ discretize values into. All factors are treated as having discrete values unless they
112
+ are specified as keys in this dictionary. Each key of this dictionary must occur as an entry
113
+ in names.
114
+ normalized : bool, default False
72
115
  Flag that determines whether or not to normalize entropy by log(num_bins)
73
- subset_mask: NDArray[np.bool_] | None
116
+ subset_mask : NDArray[np.bool_] or None, default None
74
117
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
118
+ hist_cache : dict[str, NDArray[np.intp]] or None, default None
119
+ Optional cache to store histogram counts
75
120
 
76
- Note
77
- ----
121
+ Notes
122
+ -----
78
123
  For continuous variables, histogram bins are chosen automatically. See
79
124
  numpy.histogram for details.
80
125
 
81
126
  Returns
82
127
  -------
83
- ent: NDArray[np.float64]
128
+ NDArray[np.float64]
84
129
  Entropy estimate per column of X
85
130
 
86
131
  See Also
@@ -90,47 +135,64 @@ def entropy(
90
135
  """
91
136
 
92
137
  num_factors = len(names)
93
- hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
138
+ hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
94
139
 
95
140
  ev_index = np.empty(num_factors)
96
141
  for col, cnts in enumerate(hist_counts.values()):
97
142
  # entropy in nats, normalizes counts
98
143
  ev_index[col] = sp_entropy(cnts)
99
144
  if normalized:
100
- if len(cnts) == 1:
145
+ cnt_len = np.size(cnts, 0)
146
+ if cnt_len == 1:
101
147
  # log(0)
102
148
  ev_index[col] = 0
103
149
  else:
104
- ev_index[col] /= np.log(len(cnts))
150
+ ev_index[col] /= np.log(cnt_len)
105
151
  return ev_index
106
152
 
107
153
 
108
154
  def get_num_bins(
109
- data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
155
+ data: NDArray[Any],
156
+ names: list[str],
157
+ continuous_factor_bincounts: Mapping[str, int] | None = None,
158
+ subset_mask: NDArray[np.bool_] | None = None,
159
+ hist_cache: dict[str, NDArray[np.intp]] | None = None,
110
160
  ) -> NDArray[np.float64]:
111
161
  """
112
162
  Number of bins or unique values for each metadata factor, used to
113
- normalize entropy/:term:`diversity<Diversity>`.
163
+ normalize entropy/diversity.
114
164
 
115
165
  Parameters
116
166
  ----------
117
- subset_mask: NDArray[np.bool_] | None
167
+ data : NDArray
168
+ Array containing numerical values for metadata factors
169
+ names : list[str]
170
+ Names of metadata factors -- keys of the metadata dictionary
171
+ continuous_factor_bincounts : Mapping[str, int] or None, default None
172
+ The factors in names that have continuous values, each mapped to the number of bins to
173
+ discretize values into. All factors are treated as having discrete values unless they
174
+ are specified as keys in this dictionary. Each key of this dictionary must occur as an entry
175
+ in names.
176
+ subset_mask : NDArray[np.bool_] or None, default None
118
177
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
178
+ hist_cache : dict[str, NDArray[np.intp]] or None, default None
179
+ Optional cache to store histogram counts
119
180
 
120
181
  Returns
121
182
  -------
122
183
  NDArray[np.float64]
184
+ Number of bins used in the discretization for each value in names.
123
185
  """
124
186
  # likely cached
125
- hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
187
+ hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
126
188
  num_bins = np.empty(len(hist_counts))
127
189
  for idx, cnts in enumerate(hist_counts.values()):
128
- num_bins[idx] = len(cnts)
190
+ num_bins[idx] = np.size(cnts, 0)
129
191
 
130
192
  return num_bins
131
193
 
132
194
 
133
- def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
195
+ def infer_categorical(arr: NDArray[np.float64], threshold: float = 0.2) -> NDArray[np.bool_]:
134
196
  """
135
197
  Compute fraction of feature values that are unique --- intended to be used
136
198
  for inferring whether variables are categorical.
@@ -147,14 +209,28 @@ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]
147
209
 
148
210
  def preprocess_metadata(
149
211
  class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
150
- ) -> tuple[NDArray[Any], list[str], list[bool]]:
212
+ ) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
213
+ """
214
+ Formats metadata by organizing factor names, converting labels to numeric values,
215
+ adds class labels to the dataset structure, and marks which factors are categorical.
216
+ """
217
+ # if class_labels is not numeric
218
+ class_array = to_numpy(class_labels)
219
+ if not np.issubdtype(class_array.dtype, np.integer):
220
+ unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
221
+ else:
222
+ numerical_labels = np.asarray(class_array, dtype=np.intp)
223
+ unique_classes = np.unique(class_array)
224
+
151
225
  # convert class_labels and dict of lists to matrix of metadata values
152
- preprocessed_metadata = {"class_label": np.asarray(class_labels, dtype=int)}
226
+ preprocessed_metadata = {CLASS_LABEL: numerical_labels}
153
227
 
154
228
  # map columns of dict that are not numeric (e.g. string) to numeric values
155
229
  # that mutual information and diversity functions can accommodate. Each
156
230
  # unique string receives a unique integer value.
157
231
  for k, v in metadata.items():
232
+ if k == CLASS_LABEL:
233
+ continue
158
234
  # if not numeric
159
235
  v = to_numpy(v)
160
236
  if not np.issubdtype(v.dtype, np.number):
@@ -165,45 +241,56 @@ def preprocess_metadata(
165
241
 
166
242
  data = np.stack(list(preprocessed_metadata.values()), axis=-1)
167
243
  names = list(preprocessed_metadata.keys())
168
- is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
244
+ is_categorical = [
245
+ var == CLASS_LABEL or infer_categorical(preprocessed_metadata[var].astype(np.float64), cat_thresh)[0]
246
+ for var in names
247
+ ]
169
248
 
170
- return data, names, is_categorical
249
+ return data, names, is_categorical, unique_classes
171
250
 
172
251
 
173
252
  def heatmap(
174
- data: NDArray[Any],
175
- row_labels: NDArray[Any],
176
- col_labels: NDArray[Any],
253
+ data: ArrayLike,
254
+ row_labels: list[str] | ArrayLike,
255
+ col_labels: list[str] | ArrayLike,
177
256
  xlabel: str = "",
178
257
  ylabel: str = "",
179
258
  cbarlabel: str = "",
180
- ) -> None:
259
+ ) -> Figure:
181
260
  """
182
261
  Plots a formatted heatmap
183
262
 
184
263
  Parameters
185
264
  ----------
186
- data: NDArray
265
+ data : NDArray
187
266
  Array containing numerical values for factors to plot
188
- row_labels: NDArray
189
- Array containing the labels for rows in the histogram
190
- col_labels: NDArray
191
- Array containing the labels for columns in the histogram
192
- xlabel: str, default ""
267
+ row_labels : ArrayLike
268
+ List/Array containing the labels for rows in the histogram
269
+ col_labels : ArrayLike
270
+ List/Array containing the labels for columns in the histogram
271
+ xlabel : str, default ""
193
272
  X-axis label
194
- ylabel: str, default ""
273
+ ylabel : str, default ""
195
274
  Y-axis label
196
- cbarlabel: str, default ""
275
+ cbarlabel : str, default ""
197
276
  Label for the colorbar
198
277
 
278
+ Returns
279
+ -------
280
+ matplotlib.figure.Figure
281
+ Formatted heatmap
199
282
  """
200
- import matplotlib
201
283
  import matplotlib.pyplot as plt
284
+ from matplotlib.ticker import FuncFormatter
285
+
286
+ np_data = to_numpy(data)
287
+ rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
288
+ cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
202
289
 
203
290
  fig, ax = plt.subplots(figsize=(10, 10))
204
291
 
205
292
  # Plot the heatmap
206
- im = ax.imshow(data, vmin=0, vmax=1.0)
293
+ im = ax.imshow(np_data, vmin=0, vmax=1.0)
207
294
 
208
295
  # Create colorbar
209
296
  cbar = fig.colorbar(im, shrink=0.5)
@@ -212,8 +299,8 @@ def heatmap(
212
299
  cbar.set_label(cbarlabel, loc="center")
213
300
 
214
301
  # Show all ticks and label them with the respective list entries.
215
- ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
216
- ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
302
+ ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
303
+ ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
217
304
 
218
305
  ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
219
306
  # Rotate the tick labels and set their alignment.
@@ -222,8 +309,8 @@ def heatmap(
222
309
  # Turn spines off and create white grid.
223
310
  ax.spines[:].set_visible(False)
224
311
 
225
- ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
226
- ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
312
+ ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
313
+ ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
227
314
  ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
228
315
  ax.tick_params(which="minor", bottom=False, left=False)
229
316
 
@@ -232,7 +319,7 @@ def heatmap(
232
319
  if ylabel:
233
320
  ax.set_ylabel(ylabel)
234
321
 
235
- valfmt = matplotlib.ticker.FuncFormatter(format_text) # type: ignore
322
+ valfmt = FuncFormatter(format_text)
236
323
 
237
324
  # Normalize the threshold to the images color range.
238
325
  threshold = im.norm(1.0) / 2.0
@@ -245,14 +332,14 @@ def heatmap(
245
332
  # Change the text's color depending on the data.
246
333
  textcolors = ("white", "black")
247
334
  texts = []
248
- for i in range(data.shape[0]):
249
- for j in range(data.shape[1]):
250
- kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
251
- text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) # type: ignore
335
+ for i in range(np_data.shape[0]):
336
+ for j in range(np_data.shape[1]):
337
+ kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
338
+ text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw) # type: ignore
252
339
  texts.append(text)
253
340
 
254
341
  fig.tight_layout()
255
- plt.show()
342
+ return fig
256
343
 
257
344
 
258
345
  # Function to define how the text is displayed in the heatmap
@@ -262,7 +349,7 @@ def format_text(*args: str) -> str:
262
349
 
263
350
  Parameters
264
351
  ----------
265
- *args: Tuple (str, str)
352
+ *args : tuple[str, str]
266
353
  Text to be formatted. Second element is ignored, but is a
267
354
  mandatory pass-through argument as per matplotlib.ticker.FuncFormatter
268
355
 
@@ -273,3 +360,81 @@ def format_text(*args: str) -> str:
273
360
  """
274
361
  x = args[0]
275
362
  return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
363
+
364
+
365
+ def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
366
+ """
367
+ Plots a formatted bar plot
368
+
369
+ Parameters
370
+ ----------
371
+ labels : NDArray
372
+ Array containing the labels for each bar
373
+ bar_heights : NDArray
374
+ Array containing the values for each bar
375
+
376
+ Returns
377
+ -------
378
+ matplotlib.figure.Figure
379
+ Bar plot figure
380
+ """
381
+ import matplotlib.pyplot as plt
382
+
383
+ fig, ax = plt.subplots(figsize=(10, 10))
384
+
385
+ ax.bar(labels, bar_heights)
386
+ ax.set_xlabel("Factors")
387
+
388
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
389
+
390
+ fig.tight_layout()
391
+ return fig
392
+
393
+
394
+ def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
395
+ """
396
+ Creates a single plot of all of the provided images
397
+
398
+ Parameters
399
+ ----------
400
+ images : NDArray
401
+ Array containing only the desired images to plot
402
+
403
+ Returns
404
+ -------
405
+ matplotlib.figure.Figure
406
+ Plot of all provided images
407
+ """
408
+ import matplotlib.pyplot as plt
409
+
410
+ num_images = min(num_images, len(images))
411
+
412
+ if images.ndim == 4:
413
+ images = np.moveaxis(images, 1, -1)
414
+ elif images.ndim == 3:
415
+ images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
416
+ else:
417
+ raise ValueError(
418
+ f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
419
+ )
420
+
421
+ rows = int(np.ceil(num_images / 3))
422
+ fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
423
+
424
+ if rows == 1:
425
+ for j in range(3):
426
+ if j >= len(images):
427
+ continue
428
+ axs[j].imshow(images[j])
429
+ axs[j].axis("off")
430
+ else:
431
+ for i in range(rows):
432
+ for j in range(3):
433
+ i_j = i * 3 + j
434
+ if i_j >= len(images):
435
+ continue
436
+ axs[i, j].imshow(images[i_j])
437
+ axs[i, j].axis("off")
438
+
439
+ fig.tight_layout()
440
+ return fig