dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +23 -14
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/{balance.py → _balance.py}

@@ -5,16 +5,17 @@ __all__ = []
 import contextlib
 import warnings
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Literal, overload
 
 import numpy as np
 import scipy as sp
 from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
-from dataeval.
-from dataeval.utils.
-from dataeval.utils.
+from dataeval._output import Output, set_metadata
+from dataeval.utils._bin import get_counts
+from dataeval.utils._plot import heatmap
+from dataeval.utils.data import Metadata
 
 with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
@@ -23,8 +24,8 @@ with contextlib.suppress(ImportError):
 @dataclass(frozen=True)
 class BalanceOutput(Output):
     """
-    Output class for :func
-
+    Output class for :func:`.balance` :term:`bias<Bias>` metric.
+
     Attributes
     ----------
     balance : NDArray[np.float64]
@@ -35,21 +36,62 @@ class BalanceOutput(Output):
         Estimate of mutual information between metadata factors and individual class labels
     factor_names : list[str]
         Names of each metadata factor
-
-
+    class_names : list[str]
+        List of the class labels present in the dataset
     """
 
     balance: NDArray[np.float64]
     factors: NDArray[np.float64]
     classwise: NDArray[np.float64]
     factor_names: list[str]
-
+    class_names: list[str]
+
+    @overload
+    def _by_factor_type(
+        self,
+        attr: Literal["factor_names"],
+        factor_type: Literal["discrete", "continuous", "both"],
+    ) -> list[str]: ...
+
+    @overload
+    def _by_factor_type(
+        self,
+        attr: Literal["balance", "factors", "classwise"],
+        factor_type: Literal["discrete", "continuous", "both"],
+    ) -> NDArray[np.float64]: ...
+
+    def _by_factor_type(
+        self,
+        attr: Literal["balance", "factors", "classwise", "factor_names"],
+        factor_type: Literal["discrete", "continuous", "both"],
+    ) -> NDArray[np.float64] | list[str]:
+        # if not filtering by factor_type then just return the requested attribute without mask
+        if factor_type == "both":
+            return getattr(self, attr)
+
+        # create the mask for the selected factor_type
+        mask_lambda = (
+            (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
+        )
+
+        # return the masked attribute
+        if attr == "factor_names":
+            return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
+        else:
+            factor_type_mask = [mask_lambda(x) for x in self.factor_names]
+            if attr == "factors":
+                return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
+            elif attr == "balance":
+                return self.balance[factor_type_mask]
+            elif attr == "classwise":
+                return self.classwise[:, factor_type_mask]
 
     def plot(
         self,
         row_labels: list[Any] | NDArray[Any] | None = None,
         col_labels: list[Any] | NDArray[Any] | None = None,
         plot_classwise: bool = False,
+        factor_type: Literal["discrete", "continuous", "both"] = "discrete",
     ) -> Figure:
         """
         Plot a heatmap of balance information
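Note: the new factor_type filtering relies on discretized copies of continuous factors appearing in factor_names with "-discrete"/"-continuous" suffixes. A minimal sketch of the masking behavior with made-up values (not real balance output; the _balance module path is taken from this diff):

    import numpy as np
    from dataeval.metrics.bias._balance import BalanceOutput

    # Toy output: "class", one discrete factor, and one continuous factor that
    # was also discretized, hence the paired suffixed names. Values are
    # illustrative only.
    out = BalanceOutput(
        balance=np.array([1.0, 0.3, 0.2, 0.2]),
        factors=np.array([[1.0, 0.1, 0.1], [0.1, 1.0, 0.9], [0.1, 0.9, 1.0]]),
        classwise=np.array([[1.0, 0.2, 0.1, 0.1], [1.0, 0.3, 0.2, 0.2]]),
        factor_names=["class", "weather", "altitude-discrete", "altitude-continuous"],
        class_names=["cat", "dog"],
    )

    out._by_factor_type("factor_names", "discrete")
    # ['class', 'weather', 'altitude'] -- "-continuous" entries dropped, suffix stripped
    out._by_factor_type("classwise", "continuous").shape
    # (2, 3) -- the "altitude-discrete" column is masked out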
@@ -62,15 +104,17 @@ class BalanceOutput(Output):
             List/Array containing the labels for columns in the histogram
         plot_classwise : bool, default False
             Whether to plot per-class balance instead of global balance
+        factor_type : "discrete", "continuous", or "both", default "discrete"
+            Whether to plot discretized values, continuous values, or to include both
         """
         if plot_classwise:
             if row_labels is None:
-                row_labels = self.
+                row_labels = self.class_names
             if col_labels is None:
-                col_labels = self.factor_names
+                col_labels = self._by_factor_type("factor_names", factor_type)
 
             fig = heatmap(
-                self.classwise,
+                self._by_factor_type("classwise", factor_type),
                 row_labels,
                 col_labels,
                 xlabel="Factors",
@@ -79,13 +123,19 @@ class BalanceOutput(Output):
             )
         else:
             # Combine balance and factors results
-            data = np.concatenate(
+            data = np.concatenate(
+                [
+                    self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
+                    self._by_factor_type("factors", factor_type),
+                ],
+                axis=0,
+            )
             # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
             mask = np.triu(data + 1, k=0) < 1
             # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
             heat_data = np.where(mask, np.nan, data)[:-1]
             # Creating label array for heat map axes
-            heat_labels = self.factor_names
+            heat_labels = self._by_factor_type("factor_names", factor_type)
 
             if row_labels is None:
                 row_labels = heat_labels[:-1]
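Note: for reference, a standalone demo of the triangle mask used in the else branch above (toy values). NMI values lie in [0, 1], so data + 1 makes every entry kept by np.triu at least 1, and the < 1 test flags exactly the strict lower triangle for NaN:

    import numpy as np

    data = np.array([[1.0, 0.4, 0.2],
                     [0.4, 1.0, 0.7],
                     [0.2, 0.7, 1.0]])
    mask = np.triu(data + 1, k=0) < 1              # True strictly below the diagonal
    heat_data = np.where(mask, np.nan, data)[:-1]  # last row dropped, as above
    # heat_data is now (approximately):
    # [[ 1.   0.4  0.2]
    #  [ nan  1.   0.7]]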
@@ -128,7 +178,7 @@ def balance(
     Parameters
     ----------
     metadata : Metadata
-        Preprocessed metadata
+        Preprocessed metadata
     num_neighbors : int, default 5
         Number of points to consider as neighbors
 
@@ -184,7 +234,7 @@ def balance(
     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
     data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
     discretized_data = data
-    if metadata.continuous_data
+    if len(metadata.continuous_data):
         data = np.hstack((data, metadata.continuous_data))
         discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
         discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
@@ -218,7 +268,7 @@ def balance(
     factors = nmi[1:, 1:]
 
     # assume class is a factor
-    num_classes = metadata.class_names
+    num_classes = len(metadata.class_names)
     classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
 
     # classwise targets
dataeval/metrics/bias/{coverage.py → _coverage.py}

@@ -8,12 +8,12 @@ from dataclasses import dataclass
 from typing import Any, Literal
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 from scipy.spatial.distance import pdist, squareform
 
-from dataeval.
-from dataeval.
-from dataeval.utils.
+from dataeval._output import Output, set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import ensure_embeddings, flatten, to_numpy
 
 with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
@@ -71,21 +71,21 @@ def _plot(images: NDArray[Any], num_images: int) -> Figure:
 @dataclass(frozen=True)
 class CoverageOutput(Output):
     """
-    Output class for :func
+    Output class for :func:`.coverage` :term:`bias<Bias>` metric.
 
     Attributes
     ----------
-
+    uncovered_indices : NDArray[np.intp]
         Array of uncovered indices
-
+    critical_value_radii : NDArray[np.float64]
         Array of critical value radii
-
+    coverage_radius : float
         Radius for :term:`coverage<Coverage>`
     """
 
-
-
-
+    uncovered_indices: NDArray[np.intp]
+    critical_value_radii: NDArray[np.float64]
+    coverage_radius: float
 
     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
         """
@@ -102,8 +102,9 @@ class CoverageOutput(Output):
         -------
         matplotlib.figure.Figure
         """
+
         # Determine which images to plot
-        highest_uncovered_indices = self.
+        highest_uncovered_indices = self.uncovered_indices[:top_k]
 
         # Grab the images
         images = to_numpy(images)
@@ -119,7 +120,7 @@ class CoverageOutput(Output):
 def coverage(
     embeddings: ArrayLike,
     radius_type: Literal["adaptive", "naive"] = "adaptive",
-
+    num_observations: int = 20,
     percent: float = 0.01,
 ) -> CoverageOutput:
     """
@@ -128,11 +129,11 @@ def coverage(
     Parameters
     ----------
     embeddings : ArrayLike, shape - (N, P)
-
-        Function expects the data to have 2 dimensions, N number of observations in a P-
+        Dataset embeddings as unit interval [0, 1].
+        Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
     radius_type : {"adaptive", "naive"}, default "adaptive"
         The function used to determine radius.
-
+    num_observations : int, default 20
         Number of observations required in order to be covered.
         [1] suggests that a minimum of 20-50 samples is necessary.
     percent : float, default 0.01
@@ -146,7 +147,9 @@ def coverage(
     Raises
     ------
     ValueError
-        If
+        If embeddings are not unit interval [0-1]
+    ValueError
+        If length of :term:`embeddings<Embeddings>` is less than or equal to num_observations
     ValueError
         If radius_type is unknown
 
@@ -157,10 +160,10 @@ def coverage(
     Example
     -------
     >>> results = coverage(embeddings)
-    >>> results.
+    >>> results.uncovered_indices
     array([447, 412, 8, 32, 63])
-    >>> results.
-    0.
+    >>> results.coverage_radius
+    0.17592147193757596
 
     Reference
     ---------
@@ -169,26 +172,29 @@ def coverage(
     [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
     """
 
-    # Calculate distance matrix, look at the (
-    embeddings =
-
-    if
+    # Calculate distance matrix, look at the (num_observations + 1)th farthest neighbor for each image.
+    embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=True)
+    len_embeddings = len(embeddings)
+    if len_embeddings <= num_observations:
         raise ValueError(
-            f"
+            f"Length of embeddings ({len_embeddings}) is less than or equal to the specified number of \
+                observations ({num_observations})."
         )
-
-    sorted_dists = np.sort(
-
+    embeddings_matrix = squareform(pdist(flatten(embeddings))).astype(np.float64)
+    sorted_dists = np.sort(embeddings_matrix, axis=1)
+    critical_value_radii = sorted_dists[:, num_observations + 1]
 
     d = embeddings.shape[1]
     if radius_type == "naive":
-
-
+        coverage_radius = (1 / math.sqrt(math.pi)) * (
+            (2 * num_observations * math.gamma(d / 2 + 1)) / (len_embeddings)
+        ) ** (1 / d)
+        uncovered_indices = np.where(critical_value_radii > coverage_radius)[0]
     elif radius_type == "adaptive":
-        # Use data adaptive cutoff as
-        selection = int(max(
-
-
+        # Use data adaptive cutoff as coverage_radius
+        selection = int(max(len_embeddings * percent, 1))
+        uncovered_indices = np.argsort(critical_value_radii)[::-1][:selection]
+        coverage_radius = float(np.mean(np.sort(critical_value_radii)[::-1][selection - 1 : selection + 1]))
     else:
         raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
-    return CoverageOutput(
+    return CoverageOutput(uncovered_indices, critical_value_radii, coverage_radius)
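Note: a minimal usage sketch of the reworked coverage API (random data, illustrative only; coverage is assumed to be re-exported from dataeval.metrics.bias):

    import numpy as np
    from dataeval.metrics.bias import coverage  # assumed public re-export

    # 500 observations in a 16-dimensional unit hypercube, as coverage() requires
    rng = np.random.default_rng(0)
    embeddings = rng.uniform(0.0, 1.0, size=(500, 16))

    results = coverage(embeddings, radius_type="adaptive", num_observations=20, percent=0.01)
    # adaptive: the worst-covered max(N * percent, 1) = 5 samples are flagged
    print(results.uncovered_indices.shape)  # (5,)
    print(results.coverage_radius)          # mean critical radius around the cutoff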
dataeval/metrics/bias/{diversity.py → _diversity.py}

@@ -8,12 +8,14 @@ from typing import Any, Literal
 
 import numpy as np
 import scipy as sp
-from numpy.typing import
+from numpy.typing import NDArray
 
-from dataeval.
-from dataeval.
-from dataeval.utils.
-from dataeval.utils.
+from dataeval._output import Output, set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._bin import get_counts
+from dataeval.utils._method import get_method
+from dataeval.utils._plot import heatmap
+from dataeval.utils.data import Metadata
 
 with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
@@ -37,7 +39,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
     """
     import matplotlib.pyplot as plt
 
-    fig, ax = plt.subplots(figsize=(
+    fig, ax = plt.subplots(figsize=(8, 8))
 
     ax.bar(labels, bar_heights)
     ax.set_xlabel("Factors")
@@ -51,7 +53,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
 @dataclass(frozen=True)
 class DiversityOutput(Output):
     """
-    Output class for :func
+    Output class for :func:`.diversity` :term:`bias<Bias>` metric.
 
     Attributes
     ----------
@@ -61,14 +63,14 @@ class DiversityOutput(Output):
         Classwise diversity index [n_class x n_factor]
     factor_names : list[str]
         Names of each metadata factor
-
+    class_names : list[str]
         Class labels for each value in the dataset
     """
 
     diversity_index: NDArray[np.double]
     classwise: NDArray[np.double]
     factor_names: list[str]
-
+    class_names: list[str]
 
     def plot(
         self,
@@ -90,7 +92,7 @@ class DiversityOutput(Output):
         """
         if plot_classwise:
             if row_labels is None:
-                row_labels = self.
+                row_labels = self.class_names
             if col_labels is None:
                 col_labels = self.factor_names
 
@@ -191,6 +193,9 @@ def diversity_simpson(
     return ev_index
 
 
+_DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
+
+
 @set_metadata
 def diversity(
     metadata: Metadata,
@@ -210,7 +215,7 @@ def diversity(
     Parameters
     ----------
     metadata : Metadata
-        Preprocessed metadata
+        Preprocessed metadata
     method : "simpson" or "shannon", default "simpson"
         The methodology used for defining diversity
 
@@ -251,7 +256,7 @@ def diversity(
     --------
     scipy.stats.entropy
     """
-    diversity_fn = get_method(
+    diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
     discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
     cnts = get_counts(discretized_data)
     num_bins = np.bincount(np.nonzero(cnts)[1])
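Note: get_method itself is not shown in this diff; from this call site it reads as a simple name-to-callable dispatcher over _DIVERSITY_FN_MAP. A plausible sketch, purely an assumption about dataeval.utils._method:

    from typing import Callable, Mapping, TypeVar

    T = TypeVar("T", bound=Callable)

    def get_method(method_map: Mapping[str, T], method: str) -> T:
        # Look up a callable by name, failing loudly on unknown names.
        # The real helper in dataeval.utils._method may differ in details.
        if method not in method_map:
            raise ValueError(f"Specified method {method} is not a valid method: {list(method_map)}.")
        return method_map[method]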
dataeval/metrics/bias/{parity.py → _parity.py}

@@ -2,40 +2,86 @@ from __future__ import annotations
 
 __all__ = []
 
+import contextlib
 import warnings
 from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 from scipy.stats import chisquare
 from scipy.stats.contingency import chi2_contingency, crosstab
 
-from dataeval.
-from dataeval.
-from dataeval.utils.
+from dataeval._output import Output, set_metadata
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy
+from dataeval.utils.data import Metadata
+
+with contextlib.suppress(ImportError):
+    import pandas as pd
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
 
 @dataclass(frozen=True)
-class
+class BaseParityOutput(Generic[TData], Output):
+    score: TData
+    p_value: TData
+
+    def to_dataframe(self) -> pd.DataFrame:
+        """
+        Exports the parity output results to a pandas DataFrame.
+
+        Returns
+        -------
+        pd.DataFrame
+        """
+        import pandas as pd
+
+        return pd.DataFrame(
+            index=self.factor_names,  # type: ignore - list[str] is documented as acceptable index type
+            data={
+                "score": self.score.round(2),
+                "p-value": self.p_value.round(2),
+            },
+        )
+
+
+@dataclass(frozen=True)
+class LabelParityOutput(BaseParityOutput[np.float64]):
+    """
+    Output class for :func:`.label_parity` :term:`bias<Bias>` metrics.
+
+    Attributes
+    ----------
+    score : np.float64
+        chi-squared score(s) of the test
+    p_value : np.float64
+        p-value(s) of the test
+    """
+
+
+@dataclass(frozen=True)
+class ParityOutput(BaseParityOutput[NDArray[np.float64]]):
     """
-    Output class for :func
+    Output class for :func:`.parity` :term:`bias<Bias>` metrics.
 
     Attributes
     ----------
-    score :
+    score : NDArray[np.float64]
         chi-squared score(s) of the test
-    p_value :
+    p_value : NDArray[np.float64]
         p-value(s) of the test
-
+    factor_names : list[str]
         Names of each metadata factor
+    insufficient_data: dict
+        Dictionary of metadata factors with less than 5 class occurrences per value
     """
 
-    score:
-    p_value:
-
+    # score: NDArray[np.float64]
+    # p_value: NDArray[np.float64]
+    factor_names: list[str]
+    insufficient_data: dict[str, dict[int, dict[str, int]]]
 
 
 def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
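Note: the new to_dataframe export makes the chi-square results easy to inspect; a short sketch with made-up values (real outputs come from parity(metadata); pandas must be installed):

    import numpy as np
    from dataeval.metrics.bias._parity import ParityOutput  # module path per this diff

    # Illustrative values only
    output = ParityOutput(
        score=np.array([7.36, 5.47, 0.52]),
        p_value=np.array([0.29, 0.24, 0.77]),
        factor_names=["age", "income", "gender"],
        insufficient_data={},
    )
    print(output.to_dataframe())
    # (output formatting approximate)
    #         score  p-value
    # age      7.36     0.29
    # income   5.47     0.24
    # gender   0.52     0.77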
@@ -109,7 +155,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
         raise ValueError(f"No labels found in the {label_name} dataset")
     if np.any(label_dist < 5):
         warnings.warn(
-            f"Labels {np.where(label_dist<5)[0]} in {label_name}"
+            f"Labels {np.where(label_dist < 5)[0]} in {label_name}"
             " dataset have frequencies less than 5. This may lead"
             " to invalid chi-squared evaluation.",
             UserWarning,
@@ -121,7 +167,7 @@ def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
     num_classes: int | None = None,
-) ->
+) -> LabelParityOutput:
     """
     Calculate the chi-square statistic to assess the :term:`parity<Parity>` \
     between expected and observed label distributions.
@@ -142,7 +188,7 @@ def label_parity(
 
     Returns
     -------
-
+    LabelParityOutput
         chi-squared score and :term`P-Value` of the test
 
     Raises
@@ -171,7 +217,7 @@ def label_parity(
     >>> expected_labels = rng.choice([0, 1, 2, 3, 4], (100))
     >>> observed_labels = rng.choice([2, 3, 0, 4, 1], (100))
     >>> label_parity(expected_labels, observed_labels)
-
+    LabelParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
     """
 
     # Calculate
@@ -179,8 +225,8 @@ def label_parity(
         num_classes = 0
 
     # Calculate the class frequencies associated with the datasets
-    observed_dist = np.bincount(
-    expected_dist = np.bincount(
+    observed_dist = np.bincount(as_numpy(observed_labels), minlength=num_classes)
+    expected_dist = np.bincount(as_numpy(expected_labels), minlength=num_classes)
 
     # Validate
     validate_dist(observed_dist, "observed")
@@ -202,11 +248,11 @@ def label_parity(
     )
 
     cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
-    return
+    return LabelParityOutput(cs, p)
 
 
 @set_metadata
-def parity(metadata: Metadata) -> ParityOutput
+def parity(metadata: Metadata) -> ParityOutput:
     """
     Calculate chi-square statistics to assess the linear relationship \
     between multiple factors and class labels.
@@ -218,7 +264,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
     Parameters
     ----------
     metadata : Metadata
-        Preprocessed metadata
+        Preprocessed metadata
 
     Returns
     -------
@@ -250,22 +296,21 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
     --------
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
-    >>>
-
-
-
-    ... "
-    ... "
-    ...
-    ...
-    >>>
-    >>> metadata = preprocess(metadata_dict, labels, continuous_factor_bincounts)
+    >>> metadata = generate_random_metadata(
+    ...     labels=["doctor", "artist", "teacher"],
+    ...     factors={
+    ...         "age": [25, 30, 35, 45],
+    ...         "income": [50000, 65000, 80000],
+    ...         "gender": ["M", "F"]},
+    ...     length=100,
+    ...     random_seed=175)
+    >>> metadata.continuous_factor_bins = {"age": 4, "income": 3}
     >>> parity(metadata)
-    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]),
+    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
     """  # noqa: E501
     chi_scores = np.zeros(metadata.discrete_data.shape[1])
     p_values = np.zeros_like(chi_scores)
-
+    insufficient_data = {}
     for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -279,14 +324,14 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
         current_factor_name = metadata.discrete_factor_names[i]
         for int_factor, int_class in zip(counts[0], counts[1]):
             if contingency_matrix[int_factor, int_class] > 0:
-                factor_category = unique_factor_values[int_factor]
-                if current_factor_name not in
-
-                if factor_category not in
-
-
-
-
+                factor_category = unique_factor_values[int_factor].item()
+                if current_factor_name not in insufficient_data:
+                    insufficient_data[current_factor_name] = {}
+                if factor_category not in insufficient_data[current_factor_name]:
+                    insufficient_data[current_factor_name][factor_category] = {}
+                class_name = metadata.class_names[int_class]
+                class_count = contingency_matrix[int_factor, int_class].item()
+                insufficient_data[current_factor_name][factor_category][class_name] = class_count
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
@@ -299,24 +344,7 @@ def parity(metadata: Metadata) -> ParityOutput[NDArray[np.float64]]:
         chi_scores[i] = chi2
         p_values[i] = p
 
-    if
-
-        for factor, fact_dict in not_enough_data.items():
-            stacked_msg = []
-            for key, value in fact_dict.items():
-                msg = []
-                for item in value:
-                    msg.append(f"label {item[0]}: {item[1]} occurrences")
-                flat_msg = "\n\t\t".join(msg)
-                stacked_msg.append(f"value {key} - {flat_msg}\n\t")
-            factor_msg.append(factor + " - " + "".join(stacked_msg))
-
-        message = "\n".join(factor_msg)
-
-        warnings.warn(
-            f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \n\
-            Recommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",
-            UserWarning,
-        )
+    if insufficient_data:
+        warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
 
-    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
+    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)
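Note: the verbose per-factor warning is gone; callers who want the detail can rebuild it from the structured insufficient_data field (factor -> bin value -> class name -> count). A hedged sketch, where output is assumed to be the result of parity(metadata):

    # insufficient_data: dict[str, dict[int, dict[str, int]]]
    for factor, bins in output.insufficient_data.items():
        for bin_value, class_counts in bins.items():
            for class_name, count in class_counts.items():
                print(f"{factor}={bin_value}: only {count} occurrence(s) of '{class_name}'")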