dataeval 0.73.0__py3-none-any.whl → 0.74.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +1 -1
- dataeval/detectors/drift/base.py +2 -2
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +1 -1
- dataeval/detectors/ood/__init__.py +11 -4
- dataeval/detectors/ood/ae.py +2 -1
- dataeval/detectors/ood/ae_torch.py +70 -0
- dataeval/detectors/ood/aegmm.py +4 -3
- dataeval/detectors/ood/base.py +58 -108
- dataeval/detectors/ood/base_tf.py +109 -0
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/llr.py +2 -2
- dataeval/detectors/ood/metadata_ks_compare.py +53 -14
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/detectors/ood/vaegmm.py +5 -4
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +77 -64
- dataeval/metrics/bias/coverage.py +12 -12
- dataeval/metrics/bias/diversity.py +74 -114
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +54 -158
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/shared.py +1 -1
- dataeval/utils/split_dataset.py +12 -6
- dataeval/utils/tensorflow/_internal/gmm.py +4 -24
- dataeval/utils/torch/datasets.py +2 -2
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/__init__.py +1 -1
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
- dataeval/metrics/bias/metadata.py +0 -358
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/metadata_utils.py
ADDED
@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+__all__ = []
+
+import contextlib
+from typing import Any
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+
+from dataeval.interop import to_numpy
+
+with contextlib.suppress(ImportError):
+    from matplotlib.figure import Figure
+
+
+def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+    """
+    Returns columnwise unique counts for discrete data.
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing integer values for metadata factors
+    min_num_bins : int | None, default None
+        Minimum number of bins for bincount, helps force consistency across runs
+
+    Returns
+    -------
+    NDArray[np.int_]
+        Bin counts per column of data.
+    """
+    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+    for idx in range(data.shape[1]):
+        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+    return cnt_array
+
+
+def heatmap(
+    data: ArrayLike,
+    row_labels: list[str] | ArrayLike,
+    col_labels: list[str] | ArrayLike,
+    xlabel: str = "",
+    ylabel: str = "",
+    cbarlabel: str = "",
+) -> Figure:
+    """
+    Plots a formatted heatmap
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing numerical values for factors to plot
+    row_labels : ArrayLike
+        List/Array containing the labels for rows in the histogram
+    col_labels : ArrayLike
+        List/Array containing the labels for columns in the histogram
+    xlabel : str, default ""
+        X-axis label
+    ylabel : str, default ""
+        Y-axis label
+    cbarlabel : str, default ""
+        Label for the colorbar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Formatted heatmap
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib.ticker import FuncFormatter
+
+    np_data = to_numpy(data)
+    rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
+    cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    # Plot the heatmap
+    im = ax.imshow(np_data, vmin=0, vmax=1.0)
+
+    # Create colorbar
+    cbar = fig.colorbar(im, shrink=0.5)
+    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+    cbar.set_label(cbarlabel, loc="center")
+
+    # Show all ticks and label them with the respective list entries.
+    ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
+    ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
+
+    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+    # Rotate the tick labels and set their alignment.
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    ax.spines[:].set_visible(False)
+
+    ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
+    ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    valfmt = FuncFormatter(format_text)
+
+    # Normalize the threshold to the images color range.
+    threshold = im.norm(1.0) / 2.0
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    textcolors = ("white", "black")
+    texts = []
+    for i in range(np_data.shape[0]):
+        for j in range(np_data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
+            text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw)  # type: ignore
+            texts.append(text)
+
+    fig.tight_layout()
+    return fig
+
+
+# Function to define how the text is displayed in the heatmap
+def format_text(*args: str) -> str:
+    """
+    Helper function to format text for heatmap()
+
+    Parameters
+    ----------
+    *args : tuple[str, str]
+        Text to be formatted. Second element is ignored, but is a
+        mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+    Returns
+    -------
+    str
+        Formatted text
+    """
+    x = args[0]
+    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+    """
+    Plots a formatted bar plot
+
+    Parameters
+    ----------
+    labels : NDArray
+        Array containing the labels for each bar
+    bar_heights : NDArray
+        Array containing the values for each bar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Bar plot figure
+    """
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    ax.bar(labels, bar_heights)
+    ax.set_xlabel("Factors")
+
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    fig.tight_layout()
+    return fig
+
+
+def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+    """
+    Creates a single plot of all of the provided images
+
+    Parameters
+    ----------
+    images : NDArray
+        Array containing only the desired images to plot
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Plot of all provided images
+    """
+    import matplotlib.pyplot as plt
+
+    num_images = min(num_images, len(images))
+
+    if images.ndim == 4:
+        images = np.moveaxis(images, 1, -1)
+    elif images.ndim == 3:
+        images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+    else:
+        raise ValueError(
+            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+        )
+
+    rows = int(np.ceil(num_images / 3))
+    fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+    if rows == 1:
+        for j in range(3):
+            if j >= len(images):
+                continue
+            axs[j].imshow(images[j])
+            axs[j].axis("off")
+    else:
+        for i in range(rows):
+            for j in range(3):
+                i_j = i * 3 + j
+                if i_j >= len(images):
+                    continue
+                axs[i, j].imshow(images[i_j])
+                axs[i, j].axis("off")
+
+    fig.tight_layout()
+    return fig
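Editor's note, not part of the diff: a minimal usage sketch of the plotting helpers added above, assuming a dataeval 0.74.0 install with matplotlib available. The module declares `__all__ = []`, so these are internal helpers; the labels and random scores below are made up for illustration.

import numpy as np

from dataeval.metrics.bias.metadata_utils import heatmap

# Made-up 3x4 matrix of scores in [0, 1]; heatmap() fixes its color range to [0, 1]
rng = np.random.default_rng(0)
scores = rng.uniform(0.0, 1.0, (3, 4))

fig = heatmap(
    scores,
    row_labels=["class_0", "class_1", "class_2"],
    col_labels=["age", "income", "gender", "source"],
    xlabel="Factors",
    ylabel="Classes",
    cbarlabel="Normalized score",
)
fig.savefig("bias_heatmap.png")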
dataeval/metrics/bias/parity.py
CHANGED
@@ -4,14 +4,15 @@ __all__ = ["ParityOutput", "parity", "label_parity"]
 
 import warnings
 from dataclasses import dataclass
-from typing import Any, Generic, Mapping, TypeVar
+from typing import Any, Generic, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
-from scipy.stats import chi2_contingency, chisquare
+from scipy.stats import chisquare
+from scipy.stats.contingency import chi2_contingency, crosstab
 
-from dataeval.interop import to_numpy
-from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
+from dataeval.interop import as_numpy, to_numpy
+from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
 from dataeval.output import OutputMetadata, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
@@ -28,7 +29,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
         chi-squared score(s) of the test
     p_value : np.float64 | NDArray[np.float64]
         p-value(s) of the test
-    metadata_names: list[str] | None
+    metadata_names : list[str] | None
         Names of each metadata factor
     """
 
@@ -37,92 +38,6 @@ class ParityOutput(Generic[TData], OutputMetadata):
     metadata_names: list[str] | None
 
 
-def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
-    """
-    Digitizes a list of values into a given number of bins.
-
-    Parameters
-    ----------
-    continuous_values: NDArray
-        The values to be digitized.
-    bins: int
-        The number of bins for the discrete values that continuous_values will be digitized into.
-    factor_name: str
-        The name of the factor to be digitized.
-
-    Returns
-    -------
-    NDArray
-        The digitized values
-    """
-
-    if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
-        raise TypeError(
-            f"Encountered a non-numeric value for factor {factor_name}, but the factor"
-            " was specified to be continuous. Ensure all occurrences of this factor are numeric types,"
-            f" or do not specify {factor_name} as a continuous factor."
-        )
-
-    _, bin_edges = np.histogram(continuous_values, bins=bins)
-    bin_edges[-1] = np.inf
-    bin_edges[0] = -np.inf
-    return np.digitize(continuous_values, bin_edges)
-
-
-def format_discretize_factors(
-    data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
-) -> dict[str, NDArray[Any]]:
-    """
-    Sets up the internal list of metadata factors.
-
-    Parameters
-    ----------
-    data_factors: Dict[str, NDArray]
-        The dataset factors, which are per-image attributes including class label and metadata.
-        Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Dict[str, int]
-        The factors in data_factors that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in data_factors.
-
-    Returns
-    -------
-    Dict[str, NDArray]
-        - Intrinsic per-image metadata information with the formatting that input data_factors uses.
-          Each key is a metadata factor, whose value is the discrete per-image factor values.
-    """
-
-    invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
-    if invalid_keys:
-        raise KeyError(
-            f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
-            "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
-        )
-
-    warn = []
-    metadata_factors = {}
-    for i, name in enumerate(names):
-        if name == CLASS_LABEL:
-            continue
-        if name in continuous_factor_bincounts:
-            metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
-        elif not is_categorical[i]:
-            warn.append(name)
-            metadata_factors[name] = data[:, i]
-        else:
-            metadata_factors[name] = data[:, i]
-
-    if warn:
-        warnings.warn(
-            f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
-            {warn}",
-            UserWarning,
-        )
-
-    return metadata_factors
-
-
 def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
     """
     Normalize the expected label distribution to match the total number of labels in the observed distribution.
@@ -132,14 +47,14 @@ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[
 
     Parameters
     ----------
-    expected_dist :
+    expected_dist : NDArray
        The expected label distribution. This array represents the anticipated distribution of labels.
-    observed_dist :
+    observed_dist : NDArray
        The observed label distribution. This array represents the actual distribution of labels in the dataset.
 
     Returns
     -------
-
+    NDArray
        The normalized expected distribution, scaled to have the same sum as the observed distribution.
 
     Raises
@@ -179,6 +94,8 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
     ----------
     label_dist : NDArray
         Array representing label distributions
+    label_name : str
+        String representing label name
 
     Raises
     ------
@@ -219,7 +136,7 @@ def label_parity(
         List of class labels in the expected dataset
     observed_labels : ArrayLike
         List of class labels in the observed dataset
-    num_classes : int
+    num_classes : int or None, default None
         The number of unique classes in the datasets. If not provided, the function will infer it
         from the set of unique labels in expected_labels and observed_labels
 
@@ -288,31 +205,19 @@ def label_parity(
 
 
 @set_metadata()
-def parity(
-    class_labels: ArrayLike,
-    metadata: Mapping[str, ArrayLike],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-) -> ParityOutput[NDArray[np.float64]]:
+def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
     """
-    Calculate chi-square statistics to assess the relationship between multiple factors
+    Calculate chi-square statistics to assess the linear relationship between multiple factors
     and class labels.
 
     This function computes the chi-square statistic for each metadata factor to determine if there is
-    a significant relationship between the factor values and class labels. The
-
+    a significant relationship between the factor values and class labels. The chi-square statistic is
+    only valid for linear relationships. If non-linear relationships exist, use `balance`.
 
     Parameters
     ----------
-
-
-    metadata: Mapping[str, ArrayLike]
-        The dataset factors, which are per-image metadata attributes.
-        Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Mapping[str, int] | None, default None
-        A dictionary specifying the number of bins for discretizing the continuous factors.
-        The keys should correspond to the names of continuous factors in `metadata`,
-        and the values should be the number of bins to use for discretization.
-        If not provided, no discretization is applied.
+    metadata : MetadataOutput
+        Output after running `metadata_preprocessing`
 
     Returns
    -------
@@ -326,75 +231,66 @@ def parity(
     Warning
         If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
         lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
-        factor values either 0 times or at least 5 times.
-        into fewer bins.
+        factor values either 0 times or at least 5 times.
 
     Note
     ----
-    - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
     - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
     - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
       specific factor value co-occurring with a particular class label.
     - Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
      to prevent errors in the calculation.
 
+    See Also
+    --------
+    balance
+
     Examples
     --------
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
     >>> labels = np_random_gen.choice([0, 1, 2], (100))
-    >>>
-    ...
-    ...
-    ...
-    ...
+    >>> metadata_dict = [
+    ...     {
+    ...         "age": list(np_random_gen.choice([25, 30, 35, 45], (100))),
+    ...         "income": list(np_random_gen.choice([50000, 65000, 80000], (100))),
+    ...         "gender": list(np_random_gen.choice(["M", "F"], (100))),
+    ...     }
+    ... ]
    >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-    >>>
+    >>> metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
+    >>> parity(metadata)
     ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
     """  # noqa: E501
-    if len(np.shape(class_labels)) > 1:
-        raise ValueError(
-            f"Got class labels with {len(np.shape(class_labels))}-dimensional",
-            f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
-        )
-
-    data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
-    continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
-
-    factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
-
-    # unique class labels
-    class_idx = names.index(CLASS_LABEL)
-    u_cls = np.unique(data[:, class_idx])
-
-    chi_scores = np.zeros(len(factors))
-    p_values = np.zeros(len(factors))
+    chi_scores = np.zeros(metadata.discrete_data.shape[1])
+    p_values = np.zeros_like(chi_scores)
     not_enough_data = {}
-    for i, (current_factor_name, factor_values) in enumerate(factors.items()):
-        unique_factor_values = np.unique(factor_values)
-        contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
+    for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
-
-        #
-
-
-
-
-
-
-
-
-
-        not_enough_data[current_factor_name]
-
-
+        results = crosstab(col_data, metadata.class_labels)
+        contingency_matrix = as_numpy(results.count)  # type: ignore
+
+        # Determines if any frequencies are too low
+        counts = np.nonzero(contingency_matrix < 5)
+        unique_factor_values = np.unique(col_data)
+        current_factor_name = metadata.discrete_factor_names[i]
+        for int_factor, int_class in zip(counts[0], counts[1]):
+            if contingency_matrix[int_factor, int_class] > 0:
+                factor_category = unique_factor_values[int_factor]
+                if current_factor_name not in not_enough_data:
+                    not_enough_data[current_factor_name] = {}
+                if factor_category not in not_enough_data[current_factor_name]:
+                    not_enough_data[current_factor_name][factor_category] = []
+                not_enough_data[current_factor_name][factor_category].append(
+                    (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
+                )
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
         rowsums = np.sum(contingency_matrix, axis=1)
-        rowmask = np.
+        rowmask = np.nonzero(rowsums)[0]
         contingency_matrix = contingency_matrix[rowmask]
 
         chi2, p, _, _ = chi2_contingency(contingency_matrix)
@@ -422,4 +318,4 @@ def parity(
             UserWarning,
         )
 
-    return ParityOutput(chi_scores, p_values,
+    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
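Editor's note, not part of the diff: a standalone sketch of the contingency-table flow the rewritten `parity` body relies on, built directly on `scipy.stats.contingency.crosstab` and `chi2_contingency`. The factor values and class labels below are fabricated for illustration.

import numpy as np
from scipy.stats.contingency import chi2_contingency, crosstab

# Made-up discretized factor values and class labels for ten samples
factor = np.array([0, 0, 1, 1, 2, 2, 0, 1, 2, 0])
labels = np.array([0, 1, 0, 1, 0, 1, 1, 0, 1, 0])

# crosstab returns the unique row/column values and a count matrix where
# entry (r, c) is how often factor value r co-occurs with class c
result = crosstab(factor, labels)
contingency_matrix = result.count

# Drop all-zero rows before the test, since chi2_contingency rejects them
contingency_matrix = contingency_matrix[np.sum(contingency_matrix, axis=1) > 0]

chi2, p, _, _ = chi2_contingency(contingency_matrix)
print(f"chi2={chi2:.3f}, p={p:.3f}")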
dataeval/utils/__init__.py
CHANGED
@@ -10,12 +10,12 @@ from dataeval.utils.split_dataset import split_dataset
 
 __all__ = ["split_dataset", "merge_metadata"]
 
-if _IS_TORCH_AVAILABLE:
+if _IS_TORCH_AVAILABLE:
     from dataeval.utils import torch
 
     __all__ += ["torch"]
 
-if _IS_TENSORFLOW_AVAILABLE:
+if _IS_TENSORFLOW_AVAILABLE:
     from dataeval.utils import tensorflow
 
     __all__ += ["tensorflow"]
dataeval/utils/gmm.py
ADDED
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+TGMMData = TypeVar("TGMMData")
+
+
+@dataclass
+class GaussianMixtureModelParams(Generic[TGMMData]):
+    """
+    phi : TGMMData
+        Mixture component distribution weights.
+    mu : TGMMData
+        Mixture means.
+    cov : TGMMData
+        Mixture covariance.
+    L : TGMMData
+        Cholesky decomposition of `cov`.
+    log_det_cov : TGMMData
+        Log of the determinant of `cov`.
+    """
+
+    phi: TGMMData
+    mu: TGMMData
+    cov: TGMMData
+    L: TGMMData
+    log_det_cov: TGMMData
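Editor's note, not part of the diff: a sketch of how this backend-agnostic container might be populated, using NumPy arrays in place of framework tensors. The component count, dimensions, and values are illustrative only.

import numpy as np

from dataeval.utils.gmm import GaussianMixtureModelParams

k, d = 3, 2  # mixture components, feature dimensions
cov = np.stack([np.eye(d) for _ in range(k)])  # identity covariance per component

params = GaussianMixtureModelParams(
    phi=np.full(k, 1 / k),  # equal component weights
    mu=np.zeros((k, d)),  # component means
    cov=cov,  # per-component covariance matrices
    L=np.linalg.cholesky(cov),  # Cholesky factors of cov (batched)
    log_det_cov=np.array([np.linalg.slogdet(c)[1] for c in cov]),  # log|cov|
)
print(params.phi)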
dataeval/utils/metadata.py
CHANGED
@@ -131,7 +131,9 @@ def _flatten_dict_inner(
     return items, size
 
 
-def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+def _flatten_dict(
+    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
+) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.
 
@@ -165,7 +167,7 @@ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qual
             output[k] = cv
         elif not isinstance(cv, list):
             output[k] = cv if not size else [cv] * size
-    return output
+    return output, size if size is not None else 1
 
 
 def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
@@ -188,7 +190,7 @@ def merge_metadata(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     as_numpy: bool = False,
-) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
     """
     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
 
@@ -208,8 +210,10 @@ def merge_metadata(
 
     Returns
     -------
-    dict[str, list[Any]]
+    dict[str, list[Any]] or dict[str, NDArray[Any]]
         A single dictionary containing the flattened data as lists or NumPy arrays
+    NDArray[np.int_]
+        Array defining where individual images start, helpful when working with object detection metadata
 
     Note
     ----
@@ -217,9 +221,12 @@ def merge_metadata(
 
     Example
     -------
-    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
-    >>> merge_metadata(list_metadata)
+    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
+    >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+    >>> reorganized_metadata
     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+    >>> image_indicies
+    array([0])
     """
     merged: dict[str, list[Any]] = {}
     isect: set[str] = set()
@@ -236,8 +243,11 @@ def merge_metadata(
     else:
         dicts = list(metadata)
 
-    for d in dicts:
-        flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+    image_repeats = np.zeros(len(dicts))
+    for i, d in enumerate(dicts):
+        flattened, image_repeats[i] = _flatten_dict(
+            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+        )
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
@@ -248,6 +258,16 @@ def merge_metadata(
 
     output: dict[str, Any] = {}
 
+    if image_repeats.sum() == image_repeats.size:
+        image_indicies = np.arange(image_repeats.size)
+    else:
+        image_ids = np.arange(image_repeats.size)
+        image_data = np.concatenate(
+            [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
+        )
+        _, image_unsorted = np.unique(image_data, return_index=True)
+        image_indicies = np.sort(image_unsorted)
+
     if keys:
         output["keys"] = np.array(keys) if as_numpy else keys
 
@@ -255,4 +275,4 @@ def merge_metadata(
         cv = _convert_type(merged[k])
         output[k] = np.array(cv) if as_numpy else cv
 
-    return output
+    return output, image_indicies
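Editor's note, not part of the diff: a short sketch of consuming the new two-value return from `merge_metadata`, with made-up detection-style metadata where the first image carries two targets, so its factors repeat per target.

from dataeval.utils.metadata import merge_metadata

# Two images; the first has two detections, so its per-target factors repeat
metadata = [
    {"image_id": 0, "target": [{"score": 0.9}, {"score": 0.4}]},
    {"image_id": 1, "target": [{"score": 0.7}]},
]

merged, image_indicies = merge_metadata(metadata)
print(merged["score"])  # one entry per target: [0.9, 0.4, 0.7]
print(image_indicies)  # rows where each image's targets begin: [0 2]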
dataeval/utils/shared.py
CHANGED