dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/metadata/_distance.py
@@ -0,0 +1,167 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import warnings
+ from typing import NamedTuple, cast
+
+ import numpy as np
+ from scipy.stats import iqr, ks_2samp
+ from scipy.stats import wasserstein_distance as emd
+
+ from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+ from dataeval.outputs._base import MappingOutput
+ from dataeval.typing import ArrayLike
+ from dataeval.utils.data import Metadata
+
+
+ class KSType(NamedTuple):
+     """Used to typehint scipy's internal hidden ks_2samp output"""
+
+     statistic: float
+     statistic_location: float
+     pvalue: float
+
+
+ class MetadataKSResult(NamedTuple):
+     """
+     Attributes
+     ----------
+     statistic : float
+         The KS statistic
+     location : float
+         The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
+         to the median of the reference distribution.
+     dist : float
+         The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
+     pvalue : float
+         The p-value from the KS two-sample test
+     """
+
+     statistic: float
+     location: float
+     dist: float
+     pvalue: float
+
+
+ class KSOutput(MappingOutput[str, MetadataKSResult]):
+     """
+     Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+     Attributes
+     ----------
+     key: str
+         Metadata feature names
+     value: :class:`MetadataKSResult`
+         Output per feature name containing the statistic, statistic location, distance, and pvalue.
+     """
+
+
+ def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
+     """Calculates the shift magnitude between x1 and x2 scaled by x1"""
+
+     distance = emd(x1, x2)
+
+     X = iqr(x1)
+
+     # Preferred scaling of x1
+     if X:
+         return distance / X
+
+     # Return if single-valued, else scale
+     xmin, xmax = np.min(x1), np.max(x1)
+     return distance if xmin == xmax else distance / (xmax - xmin)
+
+
+ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
+     """
+     Measures the feature-wise distance between two continuous metadata distributions and
+     computes a p-value to evaluate its significance.
+
+     Uses the Earth Mover's Distance and the Kolmogorov-Smirnov two-sample test, featurewise.
+
+     Parameters
+     ----------
+     metadata1 : Metadata
+         Class containing continuous factor names and values to be used as reference
+     metadata2 : Metadata
+         Class containing continuous factor names and values to be compared with the reference
+
+     Returns
+     -------
+     KSOutput
+         A mapping with keys corresponding to metadata feature names, and values that are
+         :class:`MetadataKSResult` objects containing the statistic, location, distance, and p-value per feature.
+
+     See Also
+     --------
+     Earth mover's distance
+
+     Kolmogorov-Smirnov two-sample test
+
+     Note
+     ----
+     This function only applies to continuous data
+
+     Examples
+     --------
+     >>> output = metadata_distance(metadata1, metadata2)
+     >>> list(output)
+     ['time', 'altitude']
+     >>> output["time"]
+     MetadataKSResult(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
+     """
+
+     _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
+     fnames = metadata1.continuous_factor_names
+
+     cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
+     cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+
+     _validate_factors_and_data(fnames, cont1)
+     _validate_factors_and_data(fnames, cont2)
+
+     N = len(cont1)
+     M = len(cont2)
+
+     # This is a simplified version of sqrt(N*M / (N+M)) < 4
+     if (N - 16) * (M - 16) < 256:
+         warnings.warn(
+             f"Sample sizes of {N}, {M} will yield unreliable p-values from the KS test. "
+             f"Recommended 32 samples per factor or at least 16 if one set has many more.",
+             UserWarning,
+         )
+
+     # Set default for statistic, location, and magnitude to zero and pvalue to one
+     results: dict[str, MetadataKSResult] = {}
+
+     # Per factor
+     for i, fname in enumerate(fnames):
+         fdata1 = cont1[:, i]  # (S, )
+         fdata2 = cont2[:, i]  # (S, )
+
+         # Min and max over both distributions
+         xmin = min(np.min(fdata1), np.min(fdata2))
+         xmax = max(np.max(fdata1), np.max(fdata2))
+
+         # Default case
+         if xmin == xmax:
+             results[fname] = MetadataKSResult(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
+             continue
+
+         ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
+
+         # Normalized location
+         loc = float((ks_result.statistic_location - xmin) / (xmax - xmin))
+
+         drift = _calculate_drift(fdata1, fdata2)
+
+         results[fname] = MetadataKSResult(
+             statistic=ks_result.statistic,
+             location=loc,
+             dist=drift,
+             pvalue=ks_result.pvalue,
+         )
+
+     return KSOutput(results)
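The per-feature dist value above is the Earth Mover's Distance between the two samples divided by the IQR of the reference sample (falling back to the reference's value range, or to the raw distance when the reference is single-valued). A minimal standalone sketch of that scaling using only NumPy and SciPy, independent of the dataeval Metadata API:

    # Illustration only: re-implements the scaling used by _calculate_drift above.
    import numpy as np
    from scipy.stats import iqr, wasserstein_distance

    rng = np.random.default_rng(0)
    reference = rng.normal(size=200)        # x1: reference feature values
    shifted = reference + 0.5               # x2: same values shifted by half a unit

    raw = wasserstein_distance(reference, shifted)   # EMD, exactly 0.5 here
    scaled = raw / iqr(reference)                    # distance in IQR units of the reference
    print(f"raw={raw:.2f}, scaled={scaled:.2f}")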
dataeval/metadata/_ood.py
@@ -0,0 +1,217 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import warnings
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+ from dataeval.outputs import OODOutput
+ from dataeval.utils.data import Metadata
+
+
+ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
+     """
+     Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
+     match exactly and data has the same number of columns (factors).
+
+     Parameters
+     ----------
+     metadata_1 : Metadata
+         The set of factor names used as reference to determine the correct factor names and length of data
+     metadata_2 : Metadata
+         The compared set of factor names and data that must match metadata_1
+
+     Returns
+     -------
+     list[str]
+         The combined discrete and continuous factor names in that order.
+     list[NDArray]
+         Combined discrete and continuous data of metadata_1
+     list[NDArray]
+         Combined discrete and continuous data of metadata_2
+
+     Raises
+     ------
+     ValueError
+         If keys do not match in metadata_1 and metadata_2
+     ValueError
+         If the length of keys does not match the length of the data
+     """
+     factor_names: list[str] = []
+     m1_data: list[NDArray] = []
+     m2_data: list[NDArray] = []
+
+     # Both metadata must have the same number of factors (cols), but not necessarily samples (rows)
+     if metadata_1.total_num_factors != metadata_2.total_num_factors:
+         raise ValueError(
+             f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
+             f"and metadata_2 ({metadata_2.total_num_factors})"
+         )
+
+     # Validate and attach discrete data
+     if metadata_1.discrete_factor_names:
+         _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
+         _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
+
+         factor_names.extend(metadata_1.discrete_factor_names)
+         m1_data.append(metadata_1.discrete_data)
+         m2_data.append(metadata_2.discrete_data)
+
+     # Validate and attach continuous data
+     if metadata_1.continuous_factor_names:
+         _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
+         _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
+
+         factor_names.extend(metadata_1.continuous_factor_names)
+         m1_data.append(metadata_1.continuous_data)
+         m2_data.append(metadata_2.continuous_data)
+
+     # Return the combined names along with the discrete and continuous data lists
+     return factor_names, m1_data, m2_data
+
+
+ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
+     """
+     Calculates deviations of the test data from the median of the reference data
+
+     Parameters
+     ----------
+     reference : NDArray
+         Reference values of shape (samples, factors)
+     test : NDArray
+         Incoming values where each sample's factors will be compared to the median of
+         the reference set's corresponding factors
+
+     Returns
+     -------
+     NDArray
+         Scaled positive and negative deviations of the test data from the reference.
+
+     Note
+     ----
+     All return values are in the range [0, pos_inf]
+     """
+
+     # Take median over samples (rows)
+     ref_median = np.median(reference, axis=0)  # (F, )
+
+     # Shift reference and test distributions by the reference median
+     ref_dev = reference - ref_median  # (S, F) - F
+     test_dev = test - ref_median  # (S_t, F) - F
+
+     # Separate positive and negative distributions
+     # Fills with nans to keep shape in both 1-D and N-D matrices
+     pdev = np.where(ref_dev > 0, ref_dev, np.nan)  # (S, F)
+     ndev = np.where(ref_dev < 0, ref_dev, np.nan)  # (S, F)
+
+     # Calculate middle of positive and negative distributions per feature
+     pscale = np.nanmedian(pdev, axis=0)  # (F, )
+     nscale = np.abs(np.nanmedian(ndev, axis=0))  # (F, )
+
+     # Replace 0's for division. Negatives should not happen
+     pscale = np.where(pscale > 0, pscale, 1.0)  # (F, )
+     nscale = np.where(nscale > 0, nscale, 1.0)  # (F, )
+
+     # Scales positive values by positive scale and negative values by negative
+     return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))  # (S_t, F)
+
+
+ def most_deviated_factors(
+     metadata_1: Metadata,
+     metadata_2: Metadata,
+     ood: OODOutput,
+ ) -> list[tuple[str, float]]:
+     """
+     Determines the greatest deviation in metadata features per out-of-distribution sample in metadata_2.
+
+     Parameters
+     ----------
+     metadata_1 : Metadata
+         A reference set of Metadata containing factor names and samples
+         with discrete and/or continuous values per factor
+     metadata_2 : Metadata
+         The set of Metadata that is tested against the reference metadata.
+         This set must have the same number of features but does not require the same number of samples.
+     ood : OODOutput
+         A class output by DataEval's OOD functions that contains which examples are OOD.
+
+     Returns
+     -------
+     list[tuple[str, float]]
+         A list of the factor name and deviation of the highest metadata deviation for each OOD example in metadata_2.
+
+     Notes
+     -----
+     1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
+        and have equivalent factor names and lengths
+     2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
+        directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+
+     Examples
+     --------
+
+     >>> from dataeval.detectors.ood import OODOutput
+
+     All samples are out-of-distribution
+
+     >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+     >>> most_deviated_factors(metadata1, metadata2, is_ood)
+     [('time', 2.0), ('time', 2.592), ('time', 3.51)]
+
+     If there are no out-of-distribution samples, an empty list is returned
+
+     >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+     >>> most_deviated_factors(metadata1, metadata2, is_ood)
+     []
+     """
+
+     ood_mask: NDArray[np.bool] = ood.is_ood
+
+     # No metadata correlated with out of distribution data
+     if not any(ood_mask):
+         return []
+
+     # Combines reference and test factor names and data if they exist and match exactly
+     # shape -> (samples, factors)
+     factor_names, md_1, md_2 = _combine_metadata(
+         metadata_1=metadata_1,
+         metadata_2=metadata_2,
+     )
+
+     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
+     metadata_ref = np.hstack(md_1) if md_1 else np.array([])
+     metadata_tst = np.hstack(md_2) if md_2 else np.array([])
+
+     if len(metadata_ref) < 3:
+         warnings.warn(
+             f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
+             UserWarning,
+         )
+         return []
+
+     if len(metadata_tst) != len(ood_mask):
+         raise ValueError(
+             f"ood and test metadata must have the same length, "
+             f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+         )
+
+     # Calculates deviations of all samples in m2_data
+     # from the median values of the corresponding index in m1_data
+     # Guaranteed for inputs to not be empty
+     deviations = _calc_median_deviations(metadata_ref, metadata_tst)
+
+     # Get most impactful factor deviation of each sample for ood samples only
+     deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
+
+     # Get indices of most impactful factors for ood samples only
+     max_factors = np.argmax(deviations, axis=1)[ood_mask]
+
+     # Get names of most impactful factors  TODO: Find better way than np.dtype(<U4)
+     most_ood_factors = np.array(factor_names)[max_factors].tolist()
+
+     # List of tuples matching the factor name with its deviation
+     return [(factor, dev) for factor, dev in zip(most_ood_factors, deviation)]
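The scaling in _calc_median_deviations is easiest to see on a single factor: deviations are measured from the reference median and divided by the median positive (or negative) deviation of the reference, so a value near 1 means "about one typical spread away". A small numeric sketch, re-implemented inline rather than imported from dataeval:

    import numpy as np

    reference = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])  # (samples, factors) with one factor
    test = np.array([[3.0], [9.0]])                            # two incoming samples

    ref_median = np.median(reference, axis=0)                  # 3.0
    ref_dev = reference - ref_median                           # [-2, -1, 0, 1, 2]
    pscale = np.nanmedian(np.where(ref_dev > 0, ref_dev, np.nan), axis=0)          # 1.5
    nscale = np.abs(np.nanmedian(np.where(ref_dev < 0, ref_dev, np.nan), axis=0))  # 1.5

    test_dev = test - ref_median
    deviations = np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))
    print(deviations.ravel())  # [0. 4.] -> second sample is 4 "typical spreads" above the median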
dataeval/metadata/_utils.py
@@ -0,0 +1,44 @@
+ __all__ = []
+
+ from numpy.typing import NDArray
+
+
+ def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
+     """
+     Raises error when two lists are not equivalent including ordering
+
+     Parameters
+     ----------
+     keys1 : list of strings
+         List of strings to compare
+     keys2 : list of strings
+         List of strings to compare
+
+     Raises
+     ------
+     ValueError
+         If lists do not have the same values, value counts, or ordering
+     """
+
+     if keys1 != keys2:
+         raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
+
+
+ def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+     """
+     Raises error when the number of factors and the number of data columns do not match
+
+     Parameters
+     ----------
+     factors : list of strings
+         List of factor names of size N
+     data : NDArray
+         Array of values with shape (M, N)
+
+     Raises
+     ------
+     ValueError
+         If the length of factors does not equal the length of the transposed data
+     """
+     if len(factors) != len(data.T):
+         raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
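For reference, the shape contract these helpers enforce: the factor list has length N and the data array is (M, N), so the factor count is compared against len(data.T), i.e. the number of columns. A trivial sketch:

    import numpy as np

    factors = ["time", "altitude"]   # N = 2 factor names
    data = np.zeros((10, 2))         # M = 10 samples by N = 2 factors

    assert len(factors) == len(data.T)  # data.T has one row per factor column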
dataeval/metrics/__init__.py
@@ -5,4 +5,4 @@ can then be analyzed in the context of a given problem.

  __all__ = ["bias", "estimators", "stats"]

- from dataeval.metrics import bias, estimators, stats
+ from . import bias, estimators, stats
dataeval/metrics/bias/__init__.py
@@ -7,6 +7,7 @@ __all__ = [
      "BalanceOutput",
      "CoverageOutput",
      "DiversityOutput",
+     "LabelParityOutput",
      "ParityOutput",
      "balance",
      "coverage",
@@ -15,7 +16,8 @@ __all__ = [
      "parity",
  ]

- from dataeval.metrics.bias.balance import BalanceOutput, balance
- from dataeval.metrics.bias.coverage import CoverageOutput, coverage
- from dataeval.metrics.bias.diversity import DiversityOutput, diversity
- from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
+ from dataeval.metrics.bias._balance import balance
+ from dataeval.metrics.bias._coverage import coverage
+ from dataeval.metrics.bias._diversity import diversity
+ from dataeval.metrics.bias._parity import label_parity, parity
+ from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
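Taken together with the renames in the file list, this __init__ diff shows the migration pattern for downstream imports: implementations move to underscore-prefixed private modules and output classes are consolidated under dataeval.outputs, while the public names stay importable from the bias namespace. A before/after sketch (not exhaustive; confirm against the 0.82.0 documentation):

    # 0.76.1: output classes lived next to the metric implementations
    # from dataeval.metrics.bias.balance import BalanceOutput, balance

    # 0.82.0: both names remain available from the public bias namespace...
    from dataeval.metrics.bias import BalanceOutput, balance

    # ...and the output class is also exposed by the new consolidated outputs
    # package, as used internally by _balance.py below.
    from dataeval.outputs import BalanceOutput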
dataeval/metrics/bias/{balance.py → _balance.py}
@@ -2,99 +2,16 @@ from __future__ import annotations

  __all__ = []

- import contextlib
  import warnings
- from dataclasses import dataclass
- from typing import Any

  import numpy as np
  import scipy as sp
- from numpy.typing import NDArray
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

- from dataeval.output import Output, set_metadata
- from dataeval.utils.metadata import Metadata, get_counts
- from dataeval.utils.plot import heatmap
-
- with contextlib.suppress(ImportError):
-     from matplotlib.figure import Figure
-
-
- @dataclass(frozen=True)
- class BalanceOutput(Output):
-     """
-     Output class for :func:`balance` :term:`bias<Bias>` metric.
-
-     Attributes
-     ----------
-     balance : NDArray[np.float64]
-         Estimate of mutual information between metadata factors and class label
-     factors : NDArray[np.float64]
-         Estimate of inter/intra-factor mutual information
-     classwise : NDArray[np.float64]
-         Estimate of mutual information between metadata factors and individual class labels
-     factor_names : list[str]
-         Names of each metadata factor
-     class_list : NDArray
-         Array of the class labels present in the dataset
-     """
-
-     balance: NDArray[np.float64]
-     factors: NDArray[np.float64]
-     classwise: NDArray[np.float64]
-     factor_names: list[str]
-     class_list: NDArray[Any]
-
-     def plot(
-         self,
-         row_labels: list[Any] | NDArray[Any] | None = None,
-         col_labels: list[Any] | NDArray[Any] | None = None,
-         plot_classwise: bool = False,
-     ) -> Figure:
-         """
-         Plot a heatmap of balance information
-
-         Parameters
-         ----------
-         row_labels : ArrayLike or None, default None
-             List/Array containing the labels for rows in the histogram
-         col_labels : ArrayLike or None, default None
-             List/Array containing the labels for columns in the histogram
-         plot_classwise : bool, default False
-             Whether to plot per-class balance instead of global balance
-         """
-         if plot_classwise:
-             if row_labels is None:
-                 row_labels = self.class_list
-             if col_labels is None:
-                 col_labels = self.factor_names
-
-             fig = heatmap(
-                 self.classwise,
-                 row_labels,
-                 col_labels,
-                 xlabel="Factors",
-                 ylabel="Class",
-                 cbarlabel="Normalized Mutual Information",
-             )
-         else:
-             # Combine balance and factors results
-             data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
-             # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
-             mask = np.triu(data + 1, k=0) < 1
-             # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
-             heat_data = np.where(mask, np.nan, data)[:-1]
-             # Creating label array for heat map axes
-             heat_labels = self.factor_names
-
-             if row_labels is None:
-                 row_labels = heat_labels[:-1]
-             if col_labels is None:
-                 col_labels = heat_labels[1:]
-
-             fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
-
-         return fig
+ from dataeval.outputs import BalanceOutput
+ from dataeval.outputs._base import set_metadata
+ from dataeval.utils._bin import get_counts
+ from dataeval.utils.data import Metadata


  def _validate_num_neighbors(num_neighbors: int) -> int:
@@ -128,7 +45,7 @@ def balance(
      Parameters
      ----------
      metadata : Metadata
-         Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
+         Preprocessed metadata
      num_neighbors : int, default 5
          Number of points to consider as neighbors

@@ -150,25 +67,22 @@ def balance(

      >>> bal = balance(metadata)
      >>> bal.balance
-     array([0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-            0.        ])
+     array([1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ])

      Return intra/interfactor balance (mutual information)

      >>> bal.factors
-     array([[0.99999935, 0.31360499, 0.26925848, 0.85201924, 0.36653548],
-            [0.31360499, 0.99999856, 0.09725766, 0.15836905, 1.98031993],
-            [0.26925848, 0.09725766, 0.99999846, 0.03713108, 0.01544656],
-            [0.85201924, 0.15836905, 0.03713108, 0.47450653, 0.25509664],
-            [0.36653548, 1.98031993, 0.01544656, 0.25509664, 1.06260686]])
+     array([[1.   , 0.314, 0.269, 0.852, 0.367],
+            [0.314, 1.   , 0.097, 0.158, 1.98 ],
+            [0.269, 0.097, 1.   , 0.037, 0.015],
+            [0.852, 0.158, 0.037, 0.475, 0.255],
+            [0.367, 1.98 , 0.015, 0.255, 1.063]])

      Return classwise balance (mutual information) of factors with individual class_labels

      >>> bal.classwise
-     array([[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-             0.        ],
-            [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-             0.        ]])
+     array([[1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ],
+            [1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ]])


      See Also
@@ -184,7 +98,7 @@ def balance(
      mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
      data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
      discretized_data = data
-     if metadata.continuous_data is not None:
+     if len(metadata.continuous_data):
          data = np.hstack((data, metadata.continuous_data))
          discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
          discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
@@ -218,7 +132,7 @@ def balance(
      factors = nmi[1:, 1:]

      # assume class is a factor
-     num_classes = metadata.class_names.size
+     num_classes = len(metadata.class_names)
      classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)

      # classwise targets