dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +7 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/{_internal → utils}/split_dataset.py +98 -33
  52. dataeval/utils/tensorflow/__init__.py +7 -6
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  67. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -8
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.1.dist-info/RECORD +0 -81
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
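The dominant change in 0.72.2 is a restructure: modules under dataeval._internal are promoted to public paths (detectors/, metrics/, utils/, workflows/, output.py, interop.py). A minimal before/after import sketch, using only module paths that appear in the rename entries and hunks below; whether the old _internal paths still resolve in 0.72.2 is not shown by this diff:

    # 0.72.1 -- internal paths
    from dataeval._internal.output import OutputMetadata, set_metadata
    from dataeval._internal.metrics.utils import entropy, preprocess_metadata

    # 0.72.2 -- public paths after the restructure
    from dataeval.output import OutputMetadata, set_metadata
    from dataeval.metrics.bias.metadata import entropy, preprocess_metadata
    from dataeval.utils.shared import flatten, get_method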
dataeval/{_internal/metrics → metrics/bias}/balance.py
@@ -1,36 +1,99 @@
  from __future__ import annotations
 
+ __all__ = ["BalanceOutput", "balance"]
+
  import warnings
  from dataclasses import dataclass
- from typing import Mapping
+ from typing import Any, Mapping
 
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
- from dataeval._internal.metrics.utils import entropy, preprocess_metadata
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
+ from dataeval.output import OutputMetadata, set_metadata
 
 
  @dataclass(frozen=True)
  class BalanceOutput(OutputMetadata):
      """
-     Output class for :func:`balance` :term:`Bias` metric
+     Output class for :func:`balance` bias metric
 
      Attributes
      ----------
      balance : NDArray[np.float64]
-         Estimate of :term:`mutual information<Mutual Information (MI)>` between metadata factors and class label
+         Estimate of mutual information between metadata factors and class label
      factors : NDArray[np.float64]
          Estimate of inter/intra-factor mutual information
      classwise : NDArray[np.float64]
          Estimate of mutual information between metadata factors and individual class labels
+     class_list: NDArray[np.int64]
+         Class labels for each value in the dataset
+     metadata_names: list[str]
+         Names of each metadata factor
      """
 
      balance: NDArray[np.float64]
      factors: NDArray[np.float64]
      classwise: NDArray[np.float64]
 
+     class_list: NDArray[np.int64]
+     metadata_names: list[str]
+
+     def plot(
+         self,
+         row_labels: NDArray[Any] | None = None,
+         col_labels: NDArray[Any] | None = None,
+         plot_classwise: bool = False,
+     ) -> None:
+         """
+         Plot a heatmap of balance information
+
+         Parameters
+         ----------
+         row_labels: NDArray | None, default None
+             Array containing the labels for rows in the histogram
+         col_labels: NDArray | None, default None
+             Array containing the labels for columns in the histogram
+         plot_classwise: bool, default False
+             Whether to plot per-class balance instead of global balance
+
+         """
+         if plot_classwise:
+             if row_labels is None:
+                 row_labels = np.unique(self.class_list)
+             if col_labels is None:
+                 col_labels = np.concatenate((["class"], self.metadata_names))
+
+             heatmap(
+                 self.classwise,
+                 row_labels,
+                 col_labels,
+                 xlabel="Factors",
+                 ylabel="Class",
+                 cbarlabel="Normalized Mutual Information",
+             )
+         else:
+             data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
+             # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
+             mask = np.triu(data + 1, k=0) < 1
+             # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
+             heat_data = np.where(mask, np.nan, data)[:-1]
+             # Creating label array for heat map axes
+             heat_labels = np.concatenate((["class"], self.metadata_names))
+
+             if row_labels is None:
+                 row_labels = heat_labels[:-1]
+             if col_labels is None:
+                 col_labels = heat_labels[1:]
+
+             heatmap(
+                 heat_data,
+                 row_labels,
+                 col_labels,
+                 cbarlabel="Normalized Mutual Information",
+             )
+
 
  def validate_num_neighbors(num_neighbors: int) -> int:
      if not isinstance(num_neighbors, (int, float)):
@@ -55,7 +118,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
  @set_metadata("dataeval.metrics")
  def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
      """
-     :term:`Mutual information (MI)` between factors (class label, metadata, label/image properties)
+     Mutual information (MI) between factors (class label, metadata, label/image properties)
 
      Parameters
      ----------
@@ -70,7 +133,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      Returns
      -------
      BalanceOutput
-         (num_factors+1) x (num_factors+1) estimate of :term:`mutual information<Mutual Information (MI)>`
+         (num_factors+1) x (num_factors+1) estimate of mutual information
          between num_factors metadata factors and class label. Symmetry is enforced.
 
      Note
@@ -83,7 +146,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
 
      Example
      -------
-     Return :term:`balance<Balance>` (:term:`mutual information<Mutual Information (MI)>`) of factors with class_labels
+     Return balance (mutual information) of factors with class_labels
 
      >>> bal = balance(class_labels, metadata)
      >>> bal.balance
@@ -114,6 +177,9 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      mi = np.empty((num_factors, num_factors))
      mi[:] = np.nan
 
+     class_idx = names.index("class_label")
+     class_lbl = np.array(data[:, class_idx], dtype=int)
+
      for idx in range(num_factors):
          tgt = data[:, idx].astype(int)
 
@@ -174,4 +240,4 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
      norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
      classwise = classwise_mi / norm_factor
 
-     return BalanceOutput(balance, factors, classwise)
+     return BalanceOutput(balance, factors, classwise, class_lbl, list(metadata.keys()))
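The BalanceOutput dataclass gains class_list, metadata_names, and a plot() helper. A hypothetical usage sketch; the example data and the re-export of balance from dataeval.metrics.bias are assumptions, while the call signature and plot flags come from the hunk above:

    import numpy as np
    from dataeval.metrics.bias import balance  # assumed re-export of metrics/bias/balance.py

    # Hypothetical example data: 100 samples, 3 classes, two metadata factors
    rng = np.random.default_rng(0)
    class_labels = rng.integers(0, 3, size=100)
    metadata = {"lighting": rng.integers(0, 2, size=100), "altitude": rng.normal(1000.0, 100.0, size=100)}

    bal = balance(class_labels, metadata)
    bal.plot()                     # factor-vs-factor mutual information heatmap
    bal.plot(plot_classwise=True)  # per-class mutual information heatmap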
dataeval/{_internal/metrics → metrics/bias}/coverage.py
@@ -1,5 +1,7 @@
  from __future__ import annotations
 
+ __all__ = ["CoverageOutput", "coverage"]
+
  import math
  from dataclasses import dataclass
  from typing import Literal
@@ -8,9 +10,9 @@ import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from scipy.spatial.distance import pdist, squareform
 
- from dataeval._internal.interop import to_numpy
- from dataeval._internal.metrics.utils import flatten
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.interop import to_numpy
+ from dataeval.output import OutputMetadata, set_metadata
+ from dataeval.utils.shared import flatten
 
 
  @dataclass(frozen=True)
@@ -33,7 +35,7 @@ class CoverageOutput(OutputMetadata):
      critical_value: float
 
 
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def coverage(
      embeddings: ArrayLike,
      radius_type: Literal["adaptive", "naive"] = "adaptive",
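Coverage is functionally unchanged here (only imports and the set_metadata decorator call change); a hypothetical call sketch, with the embedding array invented and the re-export path assumed:

    import numpy as np
    from dataeval.metrics.bias import coverage  # assumed re-export of metrics/bias/coverage.py

    embeddings = np.random.default_rng(0).random((64, 16))  # hypothetical N x D embedding array
    cvg = coverage(embeddings, radius_type="adaptive")
    print(cvg.critical_value)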
dataeval/{_internal/metrics → metrics/bias}/diversity.py
@@ -1,13 +1,16 @@
  from __future__ import annotations
 
+ __all__ = ["DiversityOutput", "diversity"]
+
  from dataclasses import dataclass
- from typing import Literal, Mapping
+ from typing import Any, Literal, Mapping
 
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
 
- from dataeval._internal.metrics.utils import entropy, get_counts, get_method, get_num_bins, preprocess_metadata
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.metrics.bias.metadata import entropy, get_counts, get_num_bins, heatmap, preprocess_metadata
+ from dataeval.output import OutputMetadata, set_metadata
+ from dataeval.utils.shared import get_method
 
 
  @dataclass(frozen=True)
@@ -21,18 +24,52 @@ class DiversityOutput(OutputMetadata):
          :term:`Diversity` index for classes and factors
      classwise : NDArray[np.float64]
          Classwise diversity index [n_class x n_factor]
+     class_list: NDArray[np.int64]
+         Class labels for each value in the dataset
+     metadata_names: list[str]
+         Names of each metadata factor
      """
 
      diversity_index: NDArray[np.float64]
      classwise: NDArray[np.float64]
 
+     class_list: NDArray[np.int64]
+     metadata_names: list[str]
+
+     method: Literal["shannon", "simpson"]
+
+     def plot(self, row_labels: NDArray[Any] | None = None, col_labels: NDArray[Any] | None = None) -> None:
+         """
+         Plot a heatmap of diversity information
+
+         Parameters
+         ----------
+         row_labels: NDArray | None, default None
+             Array containing the labels for rows in the histogram
+         col_labels: NDArray | None, default None
+             Array containing the labels for columns in the histogram
+         """
+         if row_labels is None:
+             row_labels = np.unique(self.class_list)
+         if col_labels is None:
+             col_labels = np.array(self.metadata_names)
+
+         heatmap(
+             self.classwise,
+             row_labels,
+             col_labels,
+             xlabel="Factors",
+             ylabel="Class",
+             cbarlabel=f"Normalized {self.method.title()} Index",
+         )
+
 
  def diversity_shannon(
-     data: NDArray,
+     data: NDArray[Any],
      names: list[str],
      is_categorical: list[bool],
      subset_mask: NDArray[np.bool_] | None = None,
- ) -> NDArray:
+ ) -> NDArray[np.float64]:
      """
      Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
      histogram binning, for continuous variables.
@@ -79,11 +116,11 @@ def diversity_shannon(
 
 
  def diversity_simpson(
-     data: NDArray,
+     data: NDArray[Any],
      names: list[str],
      is_categorical: list[bool],
      subset_mask: NDArray[np.bool_] | None = None,
- ) -> NDArray:
+ ) -> NDArray[np.float64]:
      """
      Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
      histogram binning, for continuous variables.
@@ -139,10 +176,7 @@ def diversity_simpson(
      return ev_index
 
 
- DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
-
-
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def diversity(
      class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
  ) -> DiversityOutput:
@@ -202,12 +236,12 @@ def diversity(
      --------
      numpy.histogram
      """
-     diversity_fn = get_method(DIVERSITY_FN_MAP, method)
+     diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
      data, names, is_categorical = preprocess_metadata(class_labels, metadata)
      diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
 
      class_idx = names.index("class_label")
-     class_lbl = data[:, class_idx]
+     class_lbl = np.array(data[:, class_idx], dtype=int)
 
      u_classes = np.unique(class_lbl)
      num_factors = len(names)
@@ -218,4 +252,4 @@ def diversity(
          diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
      div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
 
-     return DiversityOutput(diversity_index, div_no_class)
+     return DiversityOutput(diversity_index, div_no_class, class_lbl, list(metadata.keys()), method)
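DiversityOutput likewise gains class_list, metadata_names, the method literal, and a plot() helper. A hypothetical sketch reusing the class_labels/metadata arrays from the balance example above; the re-export path is an assumption, the signature comes from the hunk:

    from dataeval.metrics.bias import diversity  # assumed re-export of metrics/bias/diversity.py

    div = diversity(class_labels, metadata, method="shannon")
    div.diversity_index  # diversity index per factor, class label included
    div.classwise        # [n_class x n_factor] classwise index
    div.plot()           # heatmap labeled "Normalized Shannon Index" via method.title()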
dataeval/metrics/bias/metadata.py (new file)
@@ -0,0 +1,275 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from typing import Any, Mapping
+
+ import numpy as np
+ from numpy.typing import ArrayLike, NDArray
+ from scipy.stats import entropy as sp_entropy
+
+ from dataeval.interop import to_numpy
+
+
+ def get_counts(
+     data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+ ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
+     """
+     Initialize dictionary of histogram counts --- treat categorical values
+     as histogram bins.
+
+     Parameters
+     ----------
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     counts: Dict
+         histogram counts per metadata factor in `factors`. Each
+         factor will have a different number of bins. Counts get reused
+         across metrics, so hist_counts are cached but only if computed
+         globally, i.e. without masked samples.
+     """
+
+     hist_counts, hist_bins = {}, {}
+     # np.where needed to satisfy linter
+     mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+
+     for cdx, fn in enumerate(names):
+         # linter doesn't like double indexing
+         col_data = data[mask, cdx].squeeze()
+         if is_categorical[cdx]:
+             # if discrete, use unique values as bins
+             bins, cnts = np.unique(col_data, return_counts=True)
+         else:
+             bins = hist_bins.get(fn, "auto")
+             cnts, bins = np.histogram(col_data, bins=bins, density=True)
+
+         hist_counts[fn] = cnts
+         hist_bins[fn] = bins
+
+     return hist_counts, hist_bins
+
+
+ def entropy(
+     data: NDArray[Any],
+     names: list[str],
+     is_categorical: list[bool],
+     normalized: bool = False,
+     subset_mask: NDArray[np.bool_] | None = None,
+ ) -> NDArray[np.float64]:
+     """
+     Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
+     ClasswiseBalance, and Classwise Diversity.
+
+     Compute entropy for discrete/categorical variables and for continuous variables through standard
+     histogram binning.
+
+     Parameters
+     ----------
+     normalized: bool
+         Flag that determines whether or not to normalize entropy by log(num_bins)
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Note
+     ----
+     For continuous variables, histogram bins are chosen automatically. See
+     numpy.histogram for details.
+
+     Returns
+     -------
+     ent: NDArray[np.float64]
+         Entropy estimate per column of X
+
+     See Also
+     --------
+     numpy.histogram
+     scipy.stats.entropy
+     """
+
+     num_factors = len(names)
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+
+     ev_index = np.empty(num_factors)
+     for col, cnts in enumerate(hist_counts.values()):
+         # entropy in nats, normalizes counts
+         ev_index[col] = sp_entropy(cnts)
+         if normalized:
+             if len(cnts) == 1:
+                 # log(0)
+                 ev_index[col] = 0
+             else:
+                 ev_index[col] /= np.log(len(cnts))
+     return ev_index
+
+
+ def get_num_bins(
+     data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+ ) -> NDArray[np.float64]:
+     """
+     Number of bins or unique values for each metadata factor, used to
+     normalize entropy/:term:`diversity<Diversity>`.
+
+     Parameters
+     ----------
+     subset_mask: NDArray[np.bool_] | None
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     NDArray[np.float64]
+     """
+     # likely cached
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     num_bins = np.empty(len(hist_counts))
+     for idx, cnts in enumerate(hist_counts.values()):
+         num_bins[idx] = len(cnts)
+
+     return num_bins
+
+
+ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
+     """
+     Compute fraction of feature values that are unique --- intended to be used
+     for inferring whether variables are categorical.
+     """
+     if arr.ndim == 1:
+         arr = np.expand_dims(arr, axis=1)
+     num_samples = arr.shape[0]
+     pct_unique = np.empty(arr.shape[1])
+     for col in range(arr.shape[1]):  # type: ignore
+         uvals = np.unique(arr[:, col], axis=0)
+         pct_unique[col] = len(uvals) / num_samples
+     return pct_unique < threshold
+
+
+ def preprocess_metadata(
+     class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
+ ) -> tuple[NDArray[Any], list[str], list[bool]]:
+     # convert class_labels and dict of lists to matrix of metadata values
+     preprocessed_metadata = {"class_label": np.asarray(class_labels, dtype=int)}
+
+     # map columns of dict that are not numeric (e.g. string) to numeric values
+     # that mutual information and diversity functions can accommodate. Each
+     # unique string receives a unique integer value.
+     for k, v in metadata.items():
+         # if not numeric
+         v = to_numpy(v)
+         if not np.issubdtype(v.dtype, np.number):
+             _, mapped_vals = np.unique(v, return_inverse=True)
+             preprocessed_metadata[k] = mapped_vals
+         else:
+             preprocessed_metadata[k] = v
+
+     data = np.stack(list(preprocessed_metadata.values()), axis=-1)
+     names = list(preprocessed_metadata.keys())
+     is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
+
+     return data, names, is_categorical
+
+
+ def heatmap(
+     data: NDArray[Any],
+     row_labels: NDArray[Any],
+     col_labels: NDArray[Any],
+     xlabel: str = "",
+     ylabel: str = "",
+     cbarlabel: str = "",
+ ) -> None:
+     """
+     Plots a formatted heatmap
+
+     Parameters
+     ----------
+     data: NDArray
+         Array containing numerical values for factors to plot
+     row_labels: NDArray
+         Array containing the labels for rows in the histogram
+     col_labels: NDArray
+         Array containing the labels for columns in the histogram
+     xlabel: str, default ""
+         X-axis label
+     ylabel: str, default ""
+         Y-axis label
+     cbarlabel: str, default ""
+         Label for the colorbar
+
+     """
+     import matplotlib
+     import matplotlib.pyplot as plt
+
+     fig, ax = plt.subplots(figsize=(10, 10))
+
+     # Plot the heatmap
+     im = ax.imshow(data, vmin=0, vmax=1.0)
+
+     # Create colorbar
+     cbar = fig.colorbar(im, shrink=0.5)
+     cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+     cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+     cbar.set_label(cbarlabel, loc="center")
+
+     # Show all ticks and label them with the respective list entries.
+     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
+     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+
+     ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+     # Rotate the tick labels and set their alignment.
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     # Turn spines off and create white grid.
+     ax.spines[:].set_visible(False)
+
+     ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
+     ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
+     ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+     ax.tick_params(which="minor", bottom=False, left=False)
+
+     if xlabel:
+         ax.set_xlabel(xlabel)
+     if ylabel:
+         ax.set_ylabel(ylabel)
+
+     valfmt = matplotlib.ticker.FuncFormatter(format_text)  # type: ignore
+
+     # Normalize the threshold to the images color range.
+     threshold = im.norm(1.0) / 2.0
+
+     # Set default alignment to center, but allow it to be
+     # overwritten by textkw.
+     kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+     # Loop over the data and create a `Text` for each "pixel".
+     # Change the text's color depending on the data.
+     textcolors = ("white", "black")
+     texts = []
+     for i in range(data.shape[0]):
+         for j in range(data.shape[1]):
+             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)  # type: ignore
+             texts.append(text)
+
+     fig.tight_layout()
+     plt.show()
+
+
+ # Function to define how the text is displayed in the heatmap
+ def format_text(*args: str) -> str:
+     """
+     Helper function to format text for heatmap()
+
+     Parameters
+     ----------
+     *args: Tuple (str, str)
+         Text to be formatted. Second element is ignored, but is a
+         mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+     Returns
+     -------
+     str
+         Formatted text
+     """
+     x = args[0]
+     return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
dataeval/{_internal/metrics → metrics/bias}/parity.py
@@ -1,15 +1,17 @@
  from __future__ import annotations
 
+ __all__ = ["ParityOutput", "parity", "label_parity"]
+
  import warnings
  from dataclasses import dataclass
- from typing import Generic, Mapping, TypeVar
+ from typing import Any, Generic, Mapping, TypeVar
 
  import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from scipy.stats import chi2_contingency, chisquare
 
- from dataeval._internal.interop import to_numpy
- from dataeval._internal.output import OutputMetadata, set_metadata
+ from dataeval.interop import to_numpy
+ from dataeval.output import OutputMetadata, set_metadata
 
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -31,7 +33,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
      p_value: TData
 
 
- def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
+ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
      """
      Digitizes a list of values into a given number of bins.
 
@@ -64,8 +66,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
 
 
  def format_discretize_factors(
-     data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
- ) -> dict[str, NDArray]:
+     data_factors: Mapping[str, NDArray[Any]], continuous_factor_bincounts: Mapping[str, int]
+ ) -> dict[str, NDArray[Any]]:
      """
      Sets up the internal list of metadata factors.
 
@@ -115,7 +117,7 @@ def format_discretize_factors(
      return metadata_factors
 
 
- def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
+ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
      """
      Normalize the expected label distribution to match the total number of labels in the observed distribution.
 
@@ -162,7 +164,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
      return expected_dist
 
 
- def validate_dist(label_dist: NDArray, label_name: str):
+ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
      """
      Verifies that the given label distribution has labels and checks if
      any labels have frequencies less than 5.
@@ -191,7 +193,7 @@ def validate_dist(label_dist: NDArray, label_name: str):
      )
 
 
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def label_parity(
      expected_labels: ArrayLike,
      observed_labels: ArrayLike,
@@ -279,7 +281,7 @@ def label_parity(
      return ParityOutput(cs, p)
 
 
- @set_metadata("dataeval.metrics")
+ @set_metadata()
  def parity(
      class_labels: ArrayLike,
      data_factors: Mapping[str, ArrayLike],
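The parity changes are typing tightenings plus the decorator change; behavior is otherwise the same. A hypothetical label_parity call (inputs invented; the re-export path is assumed, and p_value is the ParityOutput field shown in the hunk above):

    import numpy as np
    from dataeval.metrics.bias import label_parity  # assumed re-export of metrics/bias/parity.py

    expected_labels = np.array([0, 1, 2] * 40)     # hypothetical reference label distribution
    observed_labels = np.array([0, 0, 1, 2] * 30)  # hypothetical observed labels
    result = label_parity(expected_labels, observed_labels)
    print(result.p_value)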
dataeval/metrics/estimators/__init__.py
@@ -2,8 +2,8 @@
  Estimators calculate performance bounds and the statistical distance between datasets.
  """
 
- from dataeval._internal.metrics.ber import BEROutput, ber
- from dataeval._internal.metrics.divergence import DivergenceOutput, divergence
- from dataeval._internal.metrics.uap import UAPOutput, uap
+ from dataeval.metrics.estimators.ber import BEROutput, ber
+ from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
+ from dataeval.metrics.estimators.uap import UAPOutput, uap
 
  __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]