dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +9 -9
- dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
- dataeval/detectors/ood/base.py +63 -113
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/metadata_ks_compare.py +52 -14
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +73 -70
- dataeval/metrics/bias/coverage.py +4 -4
- dataeval/metrics/bias/diversity.py +67 -136
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +51 -161
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/aegmm.py +0 -66
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -97
- dataeval/detectors/ood/vaegmm.py +0 -75
- dataeval/metrics/bias/metadata.py +0 -440
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -123
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.73.1.dist-info/RECORD +0 -73
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,229 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import contextlib
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from numpy.typing import ArrayLike, NDArray
|
10
|
+
|
11
|
+
from dataeval.interop import to_numpy
|
12
|
+
|
13
|
+
with contextlib.suppress(ImportError):
|
14
|
+
from matplotlib.figure import Figure
|
15
|
+
|
16
|
+
|
17
|
+
def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
    """
    Returns columnwise unique counts for discrete data.

    Parameters
    ----------
    data : NDArray
        2D array (samples x factors) containing non-negative integer values for metadata factors
    min_num_bins : int | None, default None
        Minimum number of bins for bincount, helps force consistency across runs.
        If the data contains values requiring more bins, the larger number of bins is used.

    Returns
    -------
    NDArray[np.int_]
        Bin counts per column of data, with shape (num_bins, data.shape[1]).
    """
    # Number of bins required to hold every observed value (0 when there is no data).
    needed_bins = int(data.max()) + 1 if data.size else 0
    # min_num_bins is only a lower bound: truncating below needed_bins would make
    # np.bincount return more entries than a column of cnt_array can hold.
    num_bins = needed_bins if min_num_bins is None else max(needed_bins, min_num_bins)

    cnt_array = np.zeros((num_bins, data.shape[1]), dtype=np.int_)
    for idx in range(data.shape[1]):
        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=num_bins)

    return cnt_array
|
39
|
+
|
40
|
+
|
41
|
+
def heatmap(
    data: ArrayLike,
    row_labels: list[str] | ArrayLike,
    col_labels: list[str] | ArrayLike,
    xlabel: str = "",
    ylabel: str = "",
    cbarlabel: str = "",
) -> Figure:
    """
    Plots a formatted heatmap

    Every cell is annotated with its value (formatted by ``format_text``), drawn in
    white or black depending on the cell's position on the color scale.

    Parameters
    ----------
    data : ArrayLike
        2D array containing numerical values for factors to plot. The color scale
        is fixed to the range [0.0, 1.0].
    row_labels : list[str] | ArrayLike
        List/Array containing the labels for rows in the heatmap
    col_labels : list[str] | ArrayLike
        List/Array containing the labels for columns in the heatmap
    xlabel : str, default ""
        X-axis label (not set when empty)
    ylabel : str, default ""
        Y-axis label (not set when empty)
    cbarlabel : str, default ""
        Label for the colorbar

    Returns
    -------
    matplotlib.figure.Figure
        Formatted heatmap
    """
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    np_data = to_numpy(data)
    rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
    cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)

    fig, ax = plt.subplots(figsize=(10, 10))

    # Plot the heatmap on a fixed [0, 1] color range
    im = ax.imshow(np_data, vmin=0, vmax=1.0)

    # Create colorbar
    cbar = fig.colorbar(im, shrink=0.5)
    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
    cbar.set_label(cbarlabel, loc="center")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
    ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)

    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Turn spines off and draw a white grid between cells using minor ticks at the cell edges.
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    valfmt = FuncFormatter(format_text)

    # Normalize the threshold to the image's color range; text turns black above the midpoint.
    threshold = im.norm(1.0) / 2.0
    textcolors = ("white", "black")

    # Annotate every cell ("pixel") with its formatted value.
    for i in range(np_data.shape[0]):
        for j in range(np_data.shape[1]):
            color = textcolors[int(im.norm(np_data[i, j]) > threshold)]
            label = valfmt(np_data[i, j], None)  # type: ignore
            ax.text(j, i, label, horizontalalignment="center", verticalalignment="center", color=color)

    fig.tight_layout()
    return fig
|
132
|
+
|
133
|
+
|
134
|
+
# Function to define how the text is displayed in the heatmap
def format_text(*args: Any) -> str:
    """
    Helper function to format text for heatmap()

    The value is rendered with two decimals. The leading zero of values with magnitude
    below one is dropped (``0.57`` -> ``".57"``), zero is rendered as ``"0"`` and NaN
    as an empty string.

    Parameters
    ----------
    *args : tuple[float, Any]
        Value to be formatted. Second element (tick position) is ignored, but is a
        mandatory pass-through argument as per matplotlib.ticker.FuncFormatter

    Returns
    -------
    str
        Formatted text
    """
    text = f"{args[0]:.2f}"
    if text == "nan":
        return ""
    # Split off the sign so only a leading "0" of the integer part is removed; a blanket
    # str.replace("0.", ".") would also mangle values such as 10.5 -> "1.50".
    sign, magnitude = ("-", text[1:]) if text.startswith("-") else ("", text)
    if magnitude == "0.00":
        return sign + "0"
    if magnitude.startswith("0."):
        magnitude = magnitude[1:]
    return sign + magnitude
|
152
|
+
|
153
|
+
|
154
|
+
def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
    """
    Draws a bar chart with one bar per factor.

    Parameters
    ----------
    labels : NDArray
        Labels placed under each bar along the x-axis
    bar_heights : NDArray
        Height of each bar, aligned element-wise with ``labels``

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding the bar chart
    """
    from matplotlib import pyplot

    figure, axis = pyplot.subplots(figsize=(10, 10))
    axis.bar(labels, bar_heights)
    axis.set_xlabel("Factors")

    # Tilt the factor names so that long labels do not run into each other
    tick_labels = axis.get_xticklabels()
    pyplot.setp(tick_labels, rotation=45, ha="right", rotation_mode="anchor")

    figure.tight_layout()
    return figure
|
181
|
+
|
182
|
+
|
183
|
+
def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
    """
    Creates a single plot of all of the provided images

    Images are laid out row by row in a grid with three columns. At most
    ``num_images`` images are drawn; leftover grid cells are left blank.

    Parameters
    ----------
    images : NDArray
        Array containing only the desired images to plot, either channels-first
        images with shape (N, C, H, W) or grayscale images with shape (N, H, W)
    num_images : int
        Maximum number of images to plot; clipped to the number of images provided

    Returns
    -------
    matplotlib.figure.Figure
        Plot of all provided images

    Raises
    ------
    ValueError
        If ``images`` is neither 3- nor 4-dimensional
    """
    import matplotlib.pyplot as plt

    num_images = max(0, min(num_images, len(images)))

    if images.ndim == 4:
        images = np.moveaxis(images[:num_images], 1, -1)
        # Single-channel images are expanded to RGB, matching the (N, H, W) case below
        if images.shape[-1] == 1:
            images = np.repeat(images, 3, axis=-1)
    elif images.ndim == 3:
        images = np.repeat(images[:num_images, :, :, np.newaxis], 3, axis=-1)
    else:
        raise ValueError(
            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
        )

    ncols = 3
    # plt.subplots requires at least one row, even when there is nothing to draw
    rows = max(1, int(np.ceil(num_images / ncols)))
    # squeeze=False always returns a 2D axes array, so a single loop covers every grid size
    fig, axs = plt.subplots(rows, ncols, figsize=(9, 3 * rows), squeeze=False)

    for idx, ax in enumerate(axs.flat):
        # Only the requested images are drawn; every cell has its axes hidden
        if idx < num_images:
            ax.imshow(images[idx])
        ax.axis("off")

    fig.tight_layout()
    return fig
|
dataeval/metrics/bias/parity.py
CHANGED
@@ -4,21 +4,22 @@ __all__ = ["ParityOutput", "parity", "label_parity"]
|
|
4
4
|
|
5
5
|
import warnings
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import Any, Generic,
|
7
|
+
from typing import Any, Generic, TypeVar
|
8
8
|
|
9
9
|
import numpy as np
|
10
10
|
from numpy.typing import ArrayLike, NDArray
|
11
|
-
from scipy.stats import
|
11
|
+
from scipy.stats import chisquare
|
12
|
+
from scipy.stats.contingency import chi2_contingency, crosstab
|
12
13
|
|
13
|
-
from dataeval.interop import to_numpy
|
14
|
-
from dataeval.metrics.bias.
|
15
|
-
from dataeval.output import
|
14
|
+
from dataeval.interop import as_numpy, to_numpy
|
15
|
+
from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
|
16
|
+
from dataeval.output import Output, set_metadata
|
16
17
|
|
17
18
|
TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
18
19
|
|
19
20
|
|
20
21
|
@dataclass(frozen=True)
|
21
|
-
class ParityOutput(Generic[TData],
|
22
|
+
class ParityOutput(Generic[TData], Output):
|
22
23
|
"""
|
23
24
|
Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics
|
24
25
|
|
@@ -37,97 +38,6 @@ class ParityOutput(Generic[TData], OutputMetadata):
|
|
37
38
|
metadata_names: list[str] | None
|
38
39
|
|
39
40
|
|
40
|
-
def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
|
41
|
-
"""
|
42
|
-
Digitizes a list of values into a given number of bins.
|
43
|
-
|
44
|
-
Parameters
|
45
|
-
----------
|
46
|
-
continuous_values : NDArray
|
47
|
-
The values to be digitized.
|
48
|
-
bins : int
|
49
|
-
The number of bins for the discrete values that continuous_values will be digitized into.
|
50
|
-
factor_name : str
|
51
|
-
The name of the factor to be digitized.
|
52
|
-
|
53
|
-
Returns
|
54
|
-
-------
|
55
|
-
NDArray[np.intp]
|
56
|
-
The digitized values
|
57
|
-
"""
|
58
|
-
|
59
|
-
if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
|
60
|
-
raise TypeError(
|
61
|
-
f"Encountered a non-numeric value for factor {factor_name}, but the factor"
|
62
|
-
" was specified to be continuous. Ensure all occurrences of this factor are numeric types,"
|
63
|
-
f" or do not specify {factor_name} as a continuous factor."
|
64
|
-
)
|
65
|
-
|
66
|
-
_, bin_edges = np.histogram(continuous_values, bins=bins)
|
67
|
-
bin_edges[-1] = np.inf
|
68
|
-
bin_edges[0] = -np.inf
|
69
|
-
return np.digitize(continuous_values, bin_edges)
|
70
|
-
|
71
|
-
|
72
|
-
def format_discretize_factors(
|
73
|
-
data: NDArray[Any],
|
74
|
-
names: list[str],
|
75
|
-
is_categorical: list[bool],
|
76
|
-
continuous_factor_bincounts: Mapping[str, int] | None,
|
77
|
-
) -> dict[str, NDArray[Any]]:
|
78
|
-
"""
|
79
|
-
Sets up the internal list of metadata factors.
|
80
|
-
|
81
|
-
Parameters
|
82
|
-
----------
|
83
|
-
data : NDArray
|
84
|
-
The dataset factors, which are per-image attributes including class label and metadata.
|
85
|
-
names : list[str]
|
86
|
-
The class label
|
87
|
-
continuous_factor_bincounts : Mapping[str, int] or None
|
88
|
-
The factors in data_factors that have continuous values and the array of bin counts to
|
89
|
-
discretize values into. All factors are treated as having discrete values unless they
|
90
|
-
are specified as keys in this dictionary. Each element of this array must occur as a key
|
91
|
-
in data_factors.
|
92
|
-
|
93
|
-
Returns
|
94
|
-
-------
|
95
|
-
Dict[str, NDArray]
|
96
|
-
- Intrinsic per-image metadata information with the formatting that input data_factors uses.
|
97
|
-
Each key is a metadata factor, whose value is the discrete per-image factor values.
|
98
|
-
"""
|
99
|
-
|
100
|
-
if continuous_factor_bincounts:
|
101
|
-
invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
|
102
|
-
if invalid_keys:
|
103
|
-
raise KeyError(
|
104
|
-
f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
|
105
|
-
"keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
|
106
|
-
)
|
107
|
-
|
108
|
-
warn = []
|
109
|
-
metadata_factors = {}
|
110
|
-
for i, name in enumerate(names):
|
111
|
-
if name == CLASS_LABEL:
|
112
|
-
continue
|
113
|
-
if continuous_factor_bincounts and name in continuous_factor_bincounts:
|
114
|
-
metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
|
115
|
-
elif not is_categorical[i]:
|
116
|
-
warn.append(name)
|
117
|
-
metadata_factors[name] = data[:, i]
|
118
|
-
else:
|
119
|
-
metadata_factors[name] = data[:, i]
|
120
|
-
|
121
|
-
if warn:
|
122
|
-
warnings.warn(
|
123
|
-
f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
|
124
|
-
{warn}",
|
125
|
-
UserWarning,
|
126
|
-
)
|
127
|
-
|
128
|
-
return metadata_factors
|
129
|
-
|
130
|
-
|
131
41
|
def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
|
132
42
|
"""
|
133
43
|
Normalize the expected label distribution to match the total number of labels in the observed distribution.
|
@@ -206,7 +116,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
|
|
206
116
|
)
|
207
117
|
|
208
118
|
|
209
|
-
@set_metadata
|
119
|
+
@set_metadata
|
210
120
|
def label_parity(
|
211
121
|
expected_labels: ArrayLike,
|
212
122
|
observed_labels: ArrayLike,
|
@@ -294,32 +204,20 @@ def label_parity(
|
|
294
204
|
return ParityOutput(cs, p, None)
|
295
205
|
|
296
206
|
|
297
|
-
@set_metadata
|
298
|
-
def parity(
|
299
|
-
class_labels: ArrayLike,
|
300
|
-
metadata: Mapping[str, ArrayLike],
|
301
|
-
continuous_factor_bincounts: Mapping[str, int] | None = None,
|
302
|
-
) -> ParityOutput[NDArray[np.float64]]:
|
207
|
+
@set_metadata
|
208
|
+
def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
|
303
209
|
"""
|
304
|
-
Calculate chi-square statistics to assess the relationship between multiple factors
|
210
|
+
Calculate chi-square statistics to assess the linear relationship between multiple factors
|
305
211
|
and class labels.
|
306
212
|
|
307
213
|
This function computes the chi-square statistic for each metadata factor to determine if there is
|
308
|
-
a significant relationship between the factor values and class labels. The
|
309
|
-
|
214
|
+
a significant relationship between the factor values and class labels. The chi-square statistic is
|
215
|
+
only valid for linear relationships. If non-linear relationships exist, use `balance`.
|
310
216
|
|
311
217
|
Parameters
|
312
218
|
----------
|
313
|
-
|
314
|
-
|
315
|
-
metadata : Mapping[str, ArrayLike]
|
316
|
-
The dataset factors, which are per-image metadata attributes.
|
317
|
-
Each key of dataset_factors is a factor, whose value is the per-image factor values.
|
318
|
-
continuous_factor_bincounts : Mapping[str, int] or None, default None
|
319
|
-
A dictionary specifying the number of bins for discretizing the continuous factors.
|
320
|
-
The keys should correspond to the names of continuous factors in `metadata`,
|
321
|
-
and the values should be the number of bins to use for discretization.
|
322
|
-
If not provided, no discretization is applied.
|
219
|
+
metadata : MetadataOutput
|
220
|
+
Output after running `metadata_preprocessing`
|
323
221
|
|
324
222
|
Returns
|
325
223
|
-------
|
@@ -333,74 +231,66 @@ def parity(
|
|
333
231
|
Warning
|
334
232
|
If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
|
335
233
|
lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
|
336
|
-
factor values either 0 times or at least 5 times.
|
337
|
-
into fewer bins.
|
234
|
+
factor values either 0 times or at least 5 times.
|
338
235
|
|
339
236
|
Note
|
340
237
|
----
|
341
|
-
- Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
|
342
238
|
- A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
|
343
239
|
- The function creates a contingency matrix for each factor, where each entry represents the frequency of a
|
344
240
|
specific factor value co-occurring with a particular class label.
|
345
241
|
- Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
|
346
242
|
to prevent errors in the calculation.
|
347
243
|
|
244
|
+
See Also
|
245
|
+
--------
|
246
|
+
balance
|
247
|
+
|
348
248
|
Examples
|
349
249
|
--------
|
350
250
|
Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
|
351
251
|
|
352
252
|
>>> labels = np_random_gen.choice([0, 1, 2], (100))
|
353
|
-
>>>
|
354
|
-
...
|
355
|
-
...
|
356
|
-
...
|
357
|
-
...
|
253
|
+
>>> metadata_dict = [
|
254
|
+
... {
|
255
|
+
... "age": list(np_random_gen.choice([25, 30, 35, 45], (100))),
|
256
|
+
... "income": list(np_random_gen.choice([50000, 65000, 80000], (100))),
|
257
|
+
... "gender": list(np_random_gen.choice(["M", "F"], (100))),
|
258
|
+
... }
|
259
|
+
... ]
|
358
260
|
>>> continuous_factor_bincounts = {"age": 4, "income": 3}
|
359
|
-
>>>
|
261
|
+
>>> metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
|
262
|
+
>>> parity(metadata)
|
360
263
|
ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
|
361
264
|
""" # noqa: E501
|
362
|
-
|
363
|
-
|
364
|
-
f"Got class labels with {len(np.shape(class_labels))}-dimensional",
|
365
|
-
f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
|
366
|
-
)
|
367
|
-
|
368
|
-
data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
|
369
|
-
|
370
|
-
factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
|
371
|
-
|
372
|
-
# unique class labels
|
373
|
-
class_idx = names.index(CLASS_LABEL)
|
374
|
-
u_cls = np.unique(data[:, class_idx])
|
375
|
-
|
376
|
-
chi_scores = np.zeros(len(factors))
|
377
|
-
p_values = np.zeros(len(factors))
|
265
|
+
chi_scores = np.zeros(metadata.discrete_data.shape[1])
|
266
|
+
p_values = np.zeros_like(chi_scores)
|
378
267
|
not_enough_data = {}
|
379
|
-
for i,
|
380
|
-
unique_factor_values = np.unique(factor_values)
|
381
|
-
contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
|
268
|
+
for i, col_data in enumerate(metadata.discrete_data.T):
|
382
269
|
# Builds a contingency matrix where entry at index (r,c) represents
|
383
270
|
# the frequency of current_factor_name achieving value unique_factor_values[r]
|
384
271
|
# at a data point with class c.
|
385
|
-
|
386
|
-
#
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
not_enough_data[current_factor_name]
|
397
|
-
|
398
|
-
|
272
|
+
results = crosstab(col_data, metadata.class_labels)
|
273
|
+
contingency_matrix = as_numpy(results.count) # type: ignore
|
274
|
+
|
275
|
+
# Determines if any frequencies are too low
|
276
|
+
counts = np.nonzero(contingency_matrix < 5)
|
277
|
+
unique_factor_values = np.unique(col_data)
|
278
|
+
current_factor_name = metadata.discrete_factor_names[i]
|
279
|
+
for int_factor, int_class in zip(counts[0], counts[1]):
|
280
|
+
if contingency_matrix[int_factor, int_class] > 0:
|
281
|
+
factor_category = unique_factor_values[int_factor]
|
282
|
+
if current_factor_name not in not_enough_data:
|
283
|
+
not_enough_data[current_factor_name] = {}
|
284
|
+
if factor_category not in not_enough_data[current_factor_name]:
|
285
|
+
not_enough_data[current_factor_name][factor_category] = []
|
286
|
+
not_enough_data[current_factor_name][factor_category].append(
|
287
|
+
(metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
|
288
|
+
)
|
399
289
|
|
400
290
|
# This deletes rows containing only zeros,
|
401
291
|
# because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
|
402
292
|
rowsums = np.sum(contingency_matrix, axis=1)
|
403
|
-
rowmask = np.
|
293
|
+
rowmask = np.nonzero(rowsums)[0]
|
404
294
|
contingency_matrix = contingency_matrix[rowmask]
|
405
295
|
|
406
296
|
chi2, p, _, _ = chi2_contingency(contingency_matrix)
|
@@ -428,4 +318,4 @@ def parity(
|
|
428
318
|
UserWarning,
|
429
319
|
)
|
430
320
|
|
431
|
-
return ParityOutput(chi_scores, p_values,
|
321
|
+
return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
|
@@ -20,12 +20,12 @@ from scipy.sparse import coo_matrix
|
|
20
20
|
from scipy.stats import mode
|
21
21
|
|
22
22
|
from dataeval.interop import as_numpy
|
23
|
-
from dataeval.output import
|
23
|
+
from dataeval.output import Output, set_metadata
|
24
24
|
from dataeval.utils.shared import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
|
25
25
|
|
26
26
|
|
27
27
|
@dataclass(frozen=True)
|
28
|
-
class BEROutput(
|
28
|
+
class BEROutput(Output):
|
29
29
|
"""
|
30
30
|
Output class for :func:`ber` estimator metric
|
31
31
|
|
@@ -114,7 +114,7 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
|
|
114
114
|
return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
|
115
115
|
|
116
116
|
|
117
|
-
@set_metadata
|
117
|
+
@set_metadata
|
118
118
|
def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
|
119
119
|
"""
|
120
120
|
An estimator for Multi-class :term:`Bayes error rate<Bayes Error Rate (BER)>` using FR or KNN test statistic basis
|
@@ -14,12 +14,12 @@ import numpy as np
|
|
14
14
|
from numpy.typing import ArrayLike, NDArray
|
15
15
|
|
16
16
|
from dataeval.interop import as_numpy
|
17
|
-
from dataeval.output import
|
17
|
+
from dataeval.output import Output, set_metadata
|
18
18
|
from dataeval.utils.shared import compute_neighbors, get_method, minimum_spanning_tree
|
19
19
|
|
20
20
|
|
21
21
|
@dataclass(frozen=True)
|
22
|
-
class DivergenceOutput(
|
22
|
+
class DivergenceOutput(Output):
|
23
23
|
"""
|
24
24
|
Output class for :func:`divergence` estimator metric
|
25
25
|
|
@@ -78,7 +78,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
|
|
78
78
|
return errors
|
79
79
|
|
80
80
|
|
81
|
-
@set_metadata
|
81
|
+
@set_metadata
|
82
82
|
def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
|
83
83
|
"""
|
84
84
|
Calculates the :term`divergence` and any errors between the datasets
|
@@ -14,11 +14,11 @@ from numpy.typing import ArrayLike
|
|
14
14
|
from sklearn.metrics import average_precision_score
|
15
15
|
|
16
16
|
from dataeval.interop import as_numpy
|
17
|
-
from dataeval.output import
|
17
|
+
from dataeval.output import Output, set_metadata
|
18
18
|
|
19
19
|
|
20
20
|
@dataclass(frozen=True)
|
21
|
-
class UAPOutput(
|
21
|
+
class UAPOutput(Output):
|
22
22
|
"""
|
23
23
|
Output class for :func:`uap` estimator metric
|
24
24
|
|
@@ -31,7 +31,7 @@ class UAPOutput(OutputMetadata):
|
|
31
31
|
uap: float
|
32
32
|
|
33
33
|
|
34
|
-
@set_metadata
|
34
|
+
@set_metadata
|
35
35
|
def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
|
36
36
|
"""
|
37
37
|
FR Test Statistic based estimate of the empirical mean precision for
|
dataeval/metrics/stats/base.py
CHANGED
@@ -15,7 +15,7 @@ import tqdm
|
|
15
15
|
from numpy.typing import ArrayLike, NDArray
|
16
16
|
|
17
17
|
from dataeval.interop import to_numpy_iter
|
18
|
-
from dataeval.output import
|
18
|
+
from dataeval.output import Output
|
19
19
|
from dataeval.utils.image import normalize_image_shape, rescale
|
20
20
|
|
21
21
|
DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
|
@@ -65,7 +65,7 @@ class SourceIndex(NamedTuple):
|
|
65
65
|
|
66
66
|
|
67
67
|
@dataclass(frozen=True)
|
68
|
-
class BaseStatsOutput(
|
68
|
+
class BaseStatsOutput(Output):
|
69
69
|
"""
|
70
70
|
Attributes
|
71
71
|
----------
|
@@ -15,11 +15,11 @@ from dataeval.metrics.stats.dimensionstats import (
|
|
15
15
|
from dataeval.metrics.stats.labelstats import LabelStatsOutput, labelstats
|
16
16
|
from dataeval.metrics.stats.pixelstats import PixelStatsOutput, PixelStatsProcessor
|
17
17
|
from dataeval.metrics.stats.visualstats import VisualStatsOutput, VisualStatsProcessor
|
18
|
-
from dataeval.output import
|
18
|
+
from dataeval.output import Output, set_metadata
|
19
19
|
|
20
20
|
|
21
21
|
@dataclass(frozen=True)
|
22
|
-
class DatasetStatsOutput(
|
22
|
+
class DatasetStatsOutput(Output):
|
23
23
|
"""
|
24
24
|
Output class for :func:`datasetstats` stats metric
|
25
25
|
|
@@ -41,7 +41,7 @@ class DatasetStatsOutput(OutputMetadata):
|
|
41
41
|
visualstats: VisualStatsOutput
|
42
42
|
labelstats: LabelStatsOutput | None = None
|
43
43
|
|
44
|
-
def _outputs(self) -> list[
|
44
|
+
def _outputs(self) -> list[Output]:
|
45
45
|
return [s for s in (self.dimensionstats, self.pixelstats, self.visualstats, self.labelstats) if s is not None]
|
46
46
|
|
47
47
|
def dict(self) -> dict[str, Any]:
|
@@ -54,7 +54,7 @@ class DatasetStatsOutput(OutputMetadata):
|
|
54
54
|
|
55
55
|
|
56
56
|
@dataclass(frozen=True)
|
57
|
-
class ChannelStatsOutput(
|
57
|
+
class ChannelStatsOutput(Output):
|
58
58
|
"""
|
59
59
|
Output class for :func:`channelstats` stats metric
|
60
60
|
|
@@ -84,7 +84,7 @@ class ChannelStatsOutput(OutputMetadata):
|
|
84
84
|
raise ValueError("All StatsOutput classes must contain the same number of image sources.")
|
85
85
|
|
86
86
|
|
87
|
-
@set_metadata
|
87
|
+
@set_metadata
|
88
88
|
def datasetstats(
|
89
89
|
images: Iterable[ArrayLike],
|
90
90
|
bboxes: Iterable[ArrayLike] | None = None,
|
@@ -131,7 +131,7 @@ def datasetstats(
|
|
131
131
|
return DatasetStatsOutput(*outputs, labelstats=labelstats(labels) if labels else None) # type: ignore
|
132
132
|
|
133
133
|
|
134
|
-
@set_metadata
|
134
|
+
@set_metadata
|
135
135
|
def channelstats(
|
136
136
|
images: Iterable[ArrayLike],
|
137
137
|
bboxes: Iterable[ArrayLike] | None = None,
|