dataeval 0.74.2__py3-none-any.whl → 0.76.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. dataeval/__init__.py +27 -23
  2. dataeval/detectors/__init__.py +2 -2
  3. dataeval/detectors/drift/__init__.py +14 -12
  4. dataeval/detectors/drift/base.py +3 -3
  5. dataeval/detectors/drift/cvm.py +1 -1
  6. dataeval/detectors/drift/ks.py +3 -2
  7. dataeval/detectors/drift/mmd.py +9 -7
  8. dataeval/detectors/drift/torch.py +12 -12
  9. dataeval/detectors/drift/uncertainty.py +5 -4
  10. dataeval/detectors/drift/updates.py +1 -1
  11. dataeval/detectors/linters/__init__.py +4 -4
  12. dataeval/detectors/linters/clusterer.py +5 -9
  13. dataeval/detectors/linters/duplicates.py +10 -14
  14. dataeval/detectors/linters/outliers.py +100 -5
  15. dataeval/detectors/ood/__init__.py +4 -11
  16. dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
  17. dataeval/detectors/ood/base.py +47 -160
  18. dataeval/detectors/ood/metadata_ks_compare.py +34 -42
  19. dataeval/detectors/ood/metadata_least_likely.py +3 -3
  20. dataeval/detectors/ood/metadata_ood_mi.py +6 -5
  21. dataeval/detectors/ood/mixin.py +146 -0
  22. dataeval/detectors/ood/output.py +63 -0
  23. dataeval/interop.py +7 -6
  24. dataeval/{logging.py → log.py} +2 -0
  25. dataeval/metrics/__init__.py +3 -3
  26. dataeval/metrics/bias/__init__.py +10 -13
  27. dataeval/metrics/bias/balance.py +13 -11
  28. dataeval/metrics/bias/coverage.py +53 -5
  29. dataeval/metrics/bias/diversity.py +56 -24
  30. dataeval/metrics/bias/parity.py +20 -17
  31. dataeval/metrics/estimators/__init__.py +2 -2
  32. dataeval/metrics/estimators/ber.py +7 -4
  33. dataeval/metrics/estimators/divergence.py +4 -4
  34. dataeval/metrics/estimators/uap.py +4 -4
  35. dataeval/metrics/stats/__init__.py +19 -19
  36. dataeval/metrics/stats/base.py +28 -12
  37. dataeval/metrics/stats/boxratiostats.py +13 -14
  38. dataeval/metrics/stats/datasetstats.py +49 -20
  39. dataeval/metrics/stats/dimensionstats.py +8 -8
  40. dataeval/metrics/stats/hashstats.py +14 -10
  41. dataeval/metrics/stats/labelstats.py +94 -11
  42. dataeval/metrics/stats/pixelstats.py +11 -14
  43. dataeval/metrics/stats/visualstats.py +10 -13
  44. dataeval/output.py +23 -14
  45. dataeval/utils/__init__.py +5 -14
  46. dataeval/utils/dataset/__init__.py +7 -0
  47. dataeval/utils/{torch → dataset}/datasets.py +2 -0
  48. dataeval/utils/dataset/read.py +63 -0
  49. dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
  50. dataeval/utils/image.py +2 -2
  51. dataeval/utils/metadata.py +317 -14
  52. dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +91 -71
  53. dataeval/utils/torch/__init__.py +2 -17
  54. dataeval/utils/torch/gmm.py +29 -6
  55. dataeval/utils/torch/{utils.py → internal.py} +82 -58
  56. dataeval/utils/torch/models.py +10 -8
  57. dataeval/utils/torch/trainer.py +6 -85
  58. dataeval/workflows/__init__.py +2 -5
  59. dataeval/workflows/sufficiency.py +18 -8
  60. {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
  61. dataeval-0.76.0.dist-info/METADATA +137 -0
  62. dataeval-0.76.0.dist-info/RECORD +67 -0
  63. dataeval/detectors/ood/base_torch.py +0 -109
  64. dataeval/metrics/bias/metadata_preprocessing.py +0 -285
  65. dataeval/utils/gmm.py +0 -26
  66. dataeval-0.74.2.dist-info/METADATA +0 -120
  67. dataeval-0.74.2.dist-info/RECORD +0 -66
  68. {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0

dataeval/utils/metadata.py
@@ -1,12 +1,26 @@
+ """
+ Metadata related utility functions that help organize raw metadata into \
+ :class:`Metadata` objects for use within `DataEval`.
+ """
+
  from __future__ import annotations

- __all__ = ["merge_metadata"]
+ __all__ = ["Metadata", "preprocess", "merge", "flatten"]

  import warnings
- from typing import Any, Iterable, Mapping, TypeVar, overload
+ from dataclasses import dataclass
+ from typing import Any, Iterable, Literal, Mapping, TypeVar, overload

  import numpy as np
- from numpy.typing import NDArray
+ from numpy.typing import ArrayLike, NDArray
+ from scipy.stats import wasserstein_distance as wd
+
+ from dataeval.interop import as_numpy, to_numpy
+ from dataeval.output import Output, set_metadata
+
+ DISCRETE_MIN_WD = 0.054
+ CONTINUOUS_MIN_SAMPLE_SIZE = 20
+

  T = TypeVar("T")

@@ -131,9 +145,7 @@ def _flatten_dict_inner(
      return items, size


- def _flatten_dict(
-     d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
- ) -> tuple[dict[str, Any], int]:
+ def flatten(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> tuple[dict[str, Any], int]:
      """
      Flattens a dictionary and converts values to numeric values when possible.

@@ -146,12 +158,12 @@ def _flatten_dict(
      ignore_lists : bool
          Option to skip expanding lists within metadata
      fully_qualified : bool
-         Option to return dictionary keys full qualified instead of minimized
+         Option to return dictionary keys full qualified instead of reduced

      Returns
      -------
-     dict[str, Any]
-         A flattened dictionary
+     tuple[dict[str, Any], int]
+         A tuple of the flattened dictionary and the length of detected lists in metadata
      """
      expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)

@@ -185,7 +197,7 @@ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
      return set(metadata[keys[0]]) == set(metadata[keys[1]])


- def merge_metadata(
+ def merge(
      metadata: Iterable[Mapping[str, Any]],
      ignore_lists: bool = False,
      fully_qualified: bool = False,
@@ -222,7 +234,7 @@ def merge_metadata(
      Example
      -------
      >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
-     >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+     >>> reorganized_metadata, image_indicies = merge(list_metadata)
      >>> reorganized_metadata
      {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
      >>> image_indicies
@@ -245,9 +257,7 @@ def merge_metadata(

      image_repeats = np.zeros(len(dicts))
      for i, d in enumerate(dicts):
-         flattened, image_repeats[i] = _flatten_dict(
-             d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
-         )
+         flattened, image_repeats[i] = flatten(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
          isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
          union = union.union(flattened.keys())
          for k, v in flattened.items():
@@ -276,3 +286,296 @@ def merge_metadata(
          output[k] = np.array(cv) if as_numpy else cv

      return output, image_indicies
+
+
+ @dataclass(frozen=True)
+ class Metadata(Output):
+     """
+     Dataclass containing binned metadata from the :func:`preprocess` function.
+
+     Attributes
+     ----------
+     discrete_factor_names : list[str]
+         List containing factor names for the original data that was discrete and the binned continuous data
+     discrete_data : NDArray[np.int]
+         Array containing values for the original data that was discrete and the binned continuous data
+     continuous_factor_names : list[str]
+         List containing factor names for the original continuous data
+     continuous_data : NDArray[np.int or np.double] | None
+         Array containing values for the original continuous data or None if there was no continuous data
+     class_labels : NDArray[np.int]
+         Numerical class labels for the images/objects
+     class_names : NDArray[Any]
+         Array of unique class names (for use with plotting)
+     total_num_factors : int
+         Sum of discrete_factor_names and continuous_factor_names plus 1 for class
+     """
+
+     discrete_factor_names: list[str]
+     discrete_data: NDArray[np.int_]
+     continuous_factor_names: list[str]
+     continuous_data: NDArray[np.int_ | np.double] | None
+     class_labels: NDArray[np.int_]
+     class_names: NDArray[Any]
+     total_num_factors: int
+
+
+ @set_metadata
+ def preprocess(
+     raw_metadata: Iterable[Mapping[str, Any]],
+     class_labels: ArrayLike | str,
+     continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
+     auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
+     exclude: Iterable[str] | None = None,
+ ) -> Metadata:
+     """
+     Restructures the metadata to be in the correct format for the bias functions.
+
+     This identifies whether the incoming metadata is discrete or continuous,
+     and whether the data is already binned or still needs binning.
+     It accepts a list of dictionaries containing the per image metadata and
+     automatically adjusts for multiple targets in an image.
+
+     Parameters
+     ----------
+     raw_metadata : Iterable[Mapping[str, Any]]
+         Iterable collection of metadata dictionaries to flatten and merge.
+     class_labels : ArrayLike or string
+         If arraylike, expects the labels for each image (image classification) or each object (object detection).
+         If the labels are included in the metadata dictionary, pass in the key value.
+     continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
+         User provided dictionary specifying how to bin the continuous metadata factors where the value is either
+         an int to represent the number of bins, or a list of floats representing the edges for each bin.
+     auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
+         Method by which the function will automatically bin continuous metadata factors. It is recommended
+         that the user provide the bins through the `continuous_factor_bins`.
+     exclude : Iterable[str] or None, default None
+         User provided collection of metadata keys to exclude when processing metadata.
+
+     Returns
+     -------
+     Metadata
+         Output class containing the binned metadata
+     """
+     # Transform metadata into single, flattened dictionary
+     metadata, image_repeats = merge(raw_metadata)
+
+     continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None
+
+     # Drop any excluded metadata keys
+     for k in exclude or ():
+         metadata.pop(k, None)
+         if continuous_factor_bins:
+             continuous_factor_bins.pop(k, None)
+
+     # Get the class label array in numeric form
+     class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
+     if class_array.ndim > 1:
+         raise ValueError(
+             f"Got class labels with {class_array.ndim}-dimensional "
+             f"shape {class_array.shape}, but expected a 1-dimensional array."
+         )
+     if not np.issubdtype(class_array.dtype, np.int_):
+         unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
+     else:
+         numerical_labels = class_array
+         unique_classes = np.unique(class_array)
+
+     # Bin according to user supplied bins
+     continuous_metadata = {}
+     discrete_metadata = {}
+     if continuous_factor_bins is not None and continuous_factor_bins != {}:
+         invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
+         if invalid_keys:
+             raise KeyError(
+                 f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                 "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
+                 "or add corresponding entries to the `metadata` dictionary."
+             )
+         for factor, bins in continuous_factor_bins.items():
+             discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
+             continuous_metadata[factor] = metadata[factor]
+
+     # Determine category of the rest of the keys
+     remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
+     for key in remaining_keys:
+         data = to_numpy(metadata[key])
+         if np.issubdtype(data.dtype, np.number):
+             result = _is_continuous(data, image_repeats)
+             if result:
+                 continuous_metadata[key] = data
+             unique_samples, ordinal_data = np.unique(data, return_inverse=True)
+             if unique_samples.size <= np.max([20, data.size * 0.01]):
+                 discrete_metadata[key] = ordinal_data
+             else:
+                 warnings.warn(
+                     f"A user defined binning was not provided for {key}. "
+                     f"Using the {auto_bin_method} method to discretize the data. "
+                     "It is recommended that the user rerun and supply the desired "
+                     "bins using the continuous_factor_bins parameter.",
+                     UserWarning,
+                 )
+                 discrete_metadata[key] = _bin_data(data, auto_bin_method)
+         else:
+             _, discrete_metadata[key] = np.unique(data, return_inverse=True)
+
+     # splitting out the dictionaries into the keys and values
+     discrete_factor_names = list(discrete_metadata.keys())
+     discrete_data = np.stack(list(discrete_metadata.values()), axis=-1)
+     continuous_factor_names = list(continuous_metadata.keys())
+     continuous_data = np.stack(list(continuous_metadata.values()), axis=-1) if continuous_metadata else None
+     total_num_factors = len(discrete_factor_names + continuous_factor_names) + 1
+
+     return Metadata(
+         discrete_factor_names,
+         discrete_data,
+         continuous_factor_names,
+         continuous_data,
+         numerical_labels,
+         unique_classes,
+         total_num_factors,
+     )
+
+
+ def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
+     """
+     Digitizes a list of values into a given number of bins.
+
+     Parameters
+     ----------
+     data : list | NDArray
+         The values to be digitized.
+     bins : int | Iterable[float]
+         The number of bins or list of bin edges for the discrete values that data will be digitized into.
+
+     Returns
+     -------
+     NDArray[np.intp]
+         The digitized values
+     """
+
+     if not np.all([np.issubdtype(type(n), np.number) for n in data]):
+         raise TypeError(
+             "Encountered a data value with non-numeric type when digitizing a factor. "
+             "Ensure all occurrences of continuous factors are numeric types."
+         )
+     if isinstance(bins, int):
+         _, bin_edges = np.histogram(data, bins=bins)
+         bin_edges[-1] = np.inf
+         bin_edges[0] = -np.inf
+     else:
+         bin_edges = list(bins)
+     return np.digitize(data, bin_edges)
+
+
+ def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
+     """
+     Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
+     """
+     if bin_method == "clusters":
+         # bin_edges = _binning_by_clusters(data)
+         warnings.warn(
+             "Binning by clusters is currently unavailable until changes to the clustering function go through.",
+             UserWarning,
+         )
+         bin_method = "uniform_width"
+
+     # if bin_method != "clusters":  # restore this when clusters bin_method is available
+     counts, bin_edges = np.histogram(data, bins="auto")
+     n_bins = counts.size
+     if counts[counts > 0].min() < 10:
+         counter = 20
+         while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
+             counter -= 1
+             n_bins -= 1
+             counts, bin_edges = np.histogram(data, bins=n_bins)
+
+     if bin_method == "uniform_count":
+         quantiles = np.linspace(0, 100, n_bins + 1)
+         bin_edges = np.asarray(np.percentile(data, quantiles))
+
+     bin_edges[0] = -np.inf  # type: ignore # until the clusters speed up is merged
+     bin_edges[-1] = np.inf  # type: ignore # and the _binning_by_clusters can be uncommented
+     return np.digitize(data, bin_edges)  # type: ignore
+
+
+ def _is_continuous(data: NDArray[np.number], image_indicies: NDArray[np.number]) -> bool:
+     """
+     Determines whether the data is continuous or discrete using the Wasserstein distance.
+
+     Given a 1D sample, we consider the intervals between adjacent points. For a continuous distribution,
+     a point is equally likely to lie anywhere in the interval bounded by its two neighbors. Furthermore,
+     we can put all "between neighbor" locations on the same scale of 0 to 1 by subtracting the smaller
+     neighbor and dividing out the length of the interval. (Duplicates are either assigned to zero or
+     ignored, depending on context). These normalized locations will be much more uniformly distributed
+     for continuous data than for discrete, and this gives us a way to distinguish them. Call this the
+     Normalized Near Neighbor distribution (NNN), defined on the interval [0,1].
+
+     The Wasserstein distance is available in scipy.stats.wasserstein_distance. We can use it to measure
+     how close the NNN is to a uniform distribution over [0,1]. We found that as long as a sample has at
+     least 20 points, and furthermore at least half as many points as there are discrete values, we can
+     reliably distinguish discrete from continuous samples by testing that the Wasserstein distance
+     measured from a uniform distribution is greater or less than 0.054, respectively.
+     """
+     # Check if the metadata is image specific
+     _, data_indicies_unsorted = np.unique(data, return_index=True)
+     if data_indicies_unsorted.size == image_indicies.size:
+         data_indicies = np.sort(data_indicies_unsorted)
+         if (data_indicies == image_indicies).all():
+             data = data[data_indicies]
+
+     # OLD METHOD
+     # uvals = np.unique(data)
+     # pct_unique = uvals.size / data.size
+     # return pct_unique < threshold
+
+     n_examples = len(data)
+
+     if n_examples < CONTINUOUS_MIN_SAMPLE_SIZE:
+         warnings.warn(
+             f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
+         )
+         return False
+
+     # Require at least 3 unique values before bothering with NNN
+     xu = np.unique(data, axis=None)
+     if xu.size < 3:
+         return False
+
+     Xs = np.sort(data)
+
+     X0, X1 = Xs[0:-2], Xs[2:]  # left and right neighbors
+
+     dx = np.zeros(n_examples - 2)  # no dx at end points
+     gtz = (X1 - X0) > 0  # check for dups; dx will be zero for them
+     dx[np.logical_not(gtz)] = 0.0
+
+     dx[gtz] = (Xs[1:-1] - X0)[gtz] / (X1 - X0)[gtz]  # the core idea: dx is NNN samples.
+
+     shift = wd(dx, np.linspace(0, 1, dx.size))  # how far is dx from uniform, for this feature?
+
+     return shift < DISCRETE_MIN_WD  # if NNN is close enough to uniform, consider the sample continuous.
+
+
+ def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+     """
+     Returns columnwise unique counts for discrete data.
+
+     Parameters
+     ----------
+     data : NDArray
+         Array containing integer values for metadata factors
+     min_num_bins : int | None, default None
+         Minimum number of bins for bincount, helps force consistency across runs
+
+     Returns
+     -------
+     NDArray[np.int_]
+         Bin counts per column of data.
+     """
+     max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+     cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+     for idx in range(data.shape[1]):
+         cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+     return cnt_array
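
The new `merge` and `preprocess` entry points above replace `merge_metadata` and the removed `metadata_preprocessing` module (file 64). A minimal usage sketch based only on the signatures shown in this diff; the factor names, values, and bin counts are made up for illustration and assume dataeval 0.76.0 with numpy installed.

# Sketch only: exercises the dataeval.utils.metadata API added above.
# All metadata values and bin choices here are hypothetical.
from dataeval.utils.metadata import merge, preprocess

raw_metadata = [
    {"temperature": 18.2, "altitude": 2300, "weather": "sunny"},
    {"temperature": 21.7, "altitude": 2250, "weather": "cloudy"},
    {"temperature": 25.1, "altitude": 2400, "weather": "sunny"},
]
class_labels = [0, 1, 0]

# merge() flattens the per-image dictionaries into columnar lists/arrays
merged, image_indices = merge(raw_metadata)

# preprocess() bins the factors and returns the frozen Metadata dataclass;
# supplying continuous_factor_bins avoids the auto-binning warning.
md = preprocess(raw_metadata, class_labels, continuous_factor_bins={"temperature": 2, "altitude": 3})
print(md.discrete_factor_names, md.total_num_factors)
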
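The `_is_continuous` helper above decides discrete vs. continuous by comparing the Normalized Near Neighbor (NNN) distribution against a uniform distribution using the Wasserstein distance, thresholded at DISCRETE_MIN_WD = 0.054. A standalone sketch of that idea on toy data (sample sizes and distributions are arbitrary):

# Sketch of the NNN test described in the _is_continuous docstring above.
import numpy as np
from scipy.stats import wasserstein_distance

def nnn_shift(x: np.ndarray) -> float:
    xs = np.sort(x)
    x0, x1 = xs[:-2], xs[2:]                      # left and right neighbors
    dx = np.zeros(xs.size - 2)
    gtz = (x1 - x0) > 0                           # guard against duplicate neighbors
    dx[gtz] = (xs[1:-1] - x0)[gtz] / (x1 - x0)[gtz]
    return wasserstein_distance(dx, np.linspace(0, 1, dx.size))

rng = np.random.default_rng(0)
print(nnn_shift(rng.uniform(size=500)))                    # continuous: typically well below 0.054
print(nnn_shift(rng.integers(0, 5, 500).astype(float)))    # discrete: typically well above 0.054
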

dataeval/{metrics/bias/metadata_utils.py → utils/plot.py}
@@ -6,7 +6,7 @@ import contextlib
  from typing import Any

  import numpy as np
- from numpy.typing import ArrayLike, NDArray
+ from numpy.typing import ArrayLike

  from dataeval.interop import to_numpy

@@ -14,30 +14,6 @@ with contextlib.suppress(ImportError):
  from matplotlib.figure import Figure


- def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
-     """
-     Returns columnwise unique counts for discrete data.
-
-     Parameters
-     ----------
-     data : NDArray
-         Array containing integer values for metadata factors
-     min_num_bins : int | None, default None
-         Minimum number of bins for bincount, helps force consistency across runs
-
-     Returns
-     -------
-     NDArray[np.int_]
-         Bin counts per column of data.
-     """
-     max_value = data.max() + 1 if min_num_bins is None else min_num_bins
-     cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
-     for idx in range(data.shape[1]):
-         cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
-
-     return cnt_array
-
-
  def heatmap(
      data: ArrayLike,
      row_labels: list[str] | ArrayLike,
@@ -95,12 +71,17 @@ def heatmap(
      # Rotate the tick labels and set their alignment.
      plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

-     # Turn spines off and create white grid.
-     ax.spines[:].set_visible(False)
-
-     ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
-     ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
-     ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+     light_gray = "0.9"
+     # Turn spines on and create light gray easily visible grid.
+     for spine in ax.spines.values():
+         spine.set_visible(True)
+         spine.set_color(light_gray)
+
+     xticks = np.arange(np_data.shape[1] + 1) - 0.5
+     yticks = np.arange(np_data.shape[0] + 1) - 0.5
+     ax.set_xticks(xticks, minor=True)
+     ax.set_yticks(yticks, minor=True)
+     ax.grid(which="minor", color=light_gray, linestyle="-", linewidth=3)
      ax.tick_params(which="minor", bottom=False, left=False)

      if xlabel:
@@ -151,79 +132,118 @@ def format_text(*args: str) -> str:
      return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")


- def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+ def histogram_plot(
+     data_dict: dict[str, Any],
+     log: bool = True,
+     xlabel: str = "values",
+     ylabel: str = "counts",
+ ) -> Figure:
      """
-     Plots a formatted bar plot
+     Plots a formatted histogram

      Parameters
      ----------
-     labels : NDArray
-         Array containing the labels for each bar
-     bar_heights : NDArray
-         Array containing the values for each bar
+     data_dict : dict
+         Dictionary containing the metrics and their value arrays
+     log : bool, default True
+         If True, plots the histogram on a semi-log scale (y axis)
+     xlabel : str, default "values"
+         X-axis label
+     ylabel : str, default "counts"
+         Y-axis label

      Returns
      -------
      matplotlib.figure.Figure
-         Bar plot figure
+         Formatted plot of histograms
      """
      import matplotlib.pyplot as plt

-     fig, ax = plt.subplots(figsize=(10, 10))
+     num_metrics = len(data_dict)
+     if num_metrics > 2:
+         rows = int(len(data_dict) / 3)
+         fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
+     else:
+         fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))

-     ax.bar(labels, bar_heights)
-     ax.set_xlabel("Factors")
+     for ax, metric in zip(
+         axs.flat,
+         data_dict,
+     ):
+         # Plot the histogram for the chosen metric
+         ax.hist(data_dict[metric], bins=20, log=log)

-     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+         # Add labels to the histogram
+         ax.set_title(metric)
+         ax.set_ylabel(ylabel)
+         ax.set_xlabel(xlabel)

      fig.tight_layout()
      return fig


- def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+ def channel_histogram_plot(
+     data_dict: dict[str, Any],
+     log: bool = True,
+     max_channels: int = 3,
+     ch_mask: list[bool] | None = None,
+     xlabel: str = "values",
+     ylabel: str = "counts",
+ ) -> Figure:
      """
-     Creates a single plot of all of the provided images
+     Plots a formatted heatmap

      Parameters
      ----------
-     images : NDArray
-         Array containing only the desired images to plot
+     data_dict : dict
+         Dictionary containing the metrics and their value arrays
+     log : bool, default True
+         If True, plots the histogram on a semi-log scale (y axis)
+     xlabel : str, default "values"
+         X-axis label
+     ylabel : str, default "counts"
+         Y-axis label

      Returns
      -------
      matplotlib.figure.Figure
-         Plot of all provided images
+         Formatted plot of histograms
      """
      import matplotlib.pyplot as plt

-     num_images = min(num_images, len(images))
+     channelwise_metrics = ["mean", "std", "var", "skew", "zeros", "brightness", "contrast", "darkness", "entropy"]
+     data_keys = [key for key in data_dict if key in channelwise_metrics]
+     label_kwargs = {"label": [f"Channel {i}" for i in range(max_channels)]}

-     if images.ndim == 4:
-         images = np.moveaxis(images, 1, -1)
-     elif images.ndim == 3:
-         images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+     num_metrics = len(data_keys)
+     if num_metrics > 2:
+         rows = int(len(data_keys) / 3)
+         fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
      else:
-         raise ValueError(
-             f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+         fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
+
+     for ax, metric in zip(
+         axs.flat,
+         data_keys,
+     ):
+         # Plot the histogram for the chosen metric
+         data = data_dict[metric][ch_mask].reshape(-1, max_channels)
+         ax.hist(
+             data,
+             bins=20,
+             density=True,
+             log=log,
+             **label_kwargs,
          )
+         # Only plot the labels once for channels
+         if label_kwargs:
+             ax.legend()
+             label_kwargs = {}

-     rows = int(np.ceil(num_images / 3))
-     fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
-
-     if rows == 1:
-         for j in range(3):
-             if j >= len(images):
-                 continue
-             axs[j].imshow(images[j])
-             axs[j].axis("off")
-     else:
-         for i in range(rows):
-             for j in range(3):
-                 i_j = i * 3 + j
-                 if i_j >= len(images):
-                     continue
-                 axs[i, j].imshow(images[i_j])
-                 axs[i, j].axis("off")
+         # Add labels to the histogram
+         ax.set_title(metric)
+         ax.set_ylabel(ylabel)
+         ax.set_xlabel(xlabel)

      fig.tight_layout()
      return fig
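
The bar and coverage plot helpers give way to the generic histogram helpers above in the relocated plotting module. A sketch of calling `histogram_plot` with made-up stat arrays; it assumes matplotlib is installed and uses the module path shown for file 52 above.

# Sketch only: the stat names and values below are fabricated for illustration.
import numpy as np
from dataeval.utils.plot import histogram_plot

rng = np.random.default_rng(0)
stats = {
    "mean": rng.normal(0.5, 0.1, 1000),
    "std": rng.normal(0.2, 0.05, 1000),
    "entropy": rng.normal(4.0, 0.5, 1000),
}
fig = histogram_plot(stats, log=False, xlabel="value", ylabel="count")
fig.savefig("stats_histograms.png")
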

dataeval/utils/torch/__init__.py
@@ -5,21 +5,6 @@ While these metrics can take in custom models, DataEval provides utility classes
  to create a seamless integration between custom models and DataEval's metrics.
  """

- from dataeval import _IS_TORCH_AVAILABLE, _IS_TORCHVISION_AVAILABLE
+ __all__ = ["models", "trainer"]

- __all__ = []
-
- if _IS_TORCH_AVAILABLE:
-     from dataeval.utils.torch import models, trainer
-     from dataeval.utils.torch.utils import read_dataset
-
-     __all__ += ["read_dataset", "models", "trainer"]
-
- if _IS_TORCHVISION_AVAILABLE:
-     from dataeval.utils.torch import datasets
-
-     __all__ += ["datasets"]
-
-
- del _IS_TORCH_AVAILABLE
- del _IS_TORCHVISION_AVAILABLE
+ from dataeval.utils.torch import models, trainer
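
The torch utility package no longer gates its exports behind availability checks: `models` and `trainer` are always exposed, while dataset helpers move under `dataeval.utils.dataset` (files 46-48 in the list above). A sketch of the resulting import surface; the `dataset`-side paths are inferred from the renamed files and are assumptions, not confirmed API.

# Sketch only: the last two imports are assumptions based on the file moves above.
from dataeval.utils.torch import models, trainer        # unchanged, no availability guard
from dataeval.utils.dataset import datasets             # assumed new home of torch/datasets.py
from dataeval.utils.dataset.read import read_dataset    # assumed new home of read_dataset
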

dataeval/utils/torch/gmm.py
@@ -1,7 +1,5 @@
  """
- Adapted for Pytorch from:
-
- Source code derived from Alibi-Detect 0.11.4
+ Adapted for Pytorch from Alibi-Detect 0.11.4
  https://github.com/SeldonIO/alibi-detect/tree/v0.11.4

  Original code Copyright (c) 2023 Seldon Technologies Ltd
@@ -10,13 +8,38 @@ Licensed under Apache Software License (Apache 2.0)

  from __future__ import annotations

+ __all__ = []
+
+ from typing import NamedTuple, TypeVar
+
  import numpy as np
  import torch

- from dataeval.utils.gmm import GaussianMixtureModelParams
+ TGMMData = TypeVar("TGMMData")
+
+
+ class GaussianMixtureModelParams(NamedTuple):
+     """
+     phi : torch.Tensor
+         Mixture component distribution weights.
+     mu : torch.Tensor
+         Mixture means.
+     cov : torch.Tensor
+         Mixture covariance.
+     L : torch.Tensor
+         Cholesky decomposition of `cov`.
+     log_det_cov : torch.Tensor
+         Log of the determinant of `cov`.
+     """
+
+     phi: torch.Tensor
+     mu: torch.Tensor
+     cov: torch.Tensor
+     L: torch.Tensor
+     log_det_cov: torch.Tensor


- def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams[torch.Tensor]:
+ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams:
      """
      Compute parameters of Gaussian Mixture Model.

@@ -58,7 +81,7 @@ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelPara

  def gmm_energy(
      z: torch.Tensor,
-     params: GaussianMixtureModelParams[torch.Tensor],
+     params: GaussianMixtureModelParams,
      return_mean: bool = True,
  ) -> tuple[torch.Tensor, torch.Tensor]:
      """