dataeval 0.72.2__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +14 -6
- dataeval/detectors/ood/aegmm.py +14 -6
- dataeval/detectors/ood/base.py +9 -3
- dataeval/detectors/ood/llr.py +22 -16
- dataeval/detectors/ood/vae.py +14 -6
- dataeval/detectors/ood/vaegmm.py +14 -6
- dataeval/interop.py +9 -7
- dataeval/metrics/bias/balance.py +25 -29
- dataeval/metrics/bias/coverage.py +35 -3
- dataeval/metrics/bias/diversity.py +50 -27
- dataeval/metrics/bias/metadata.py +99 -16
- dataeval/metrics/bias/parity.py +43 -35
- dataeval/utils/__init__.py +2 -1
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/tensorflow/_internal/gmm.py +8 -2
- dataeval/utils/tensorflow/_internal/loss.py +20 -11
- dataeval/utils/tensorflow/_internal/{pixelcnn.py → models.py} +371 -77
- dataeval/utils/tensorflow/_internal/trainer.py +12 -5
- dataeval/utils/tensorflow/_internal/utils.py +70 -71
- {dataeval-0.72.2.dist-info → dataeval-0.73.0.dist-info}/METADATA +3 -3
- {dataeval-0.72.2.dist-info → dataeval-0.73.0.dist-info}/RECORD +25 -24
- dataeval/utils/tensorflow/_internal/autoencoder.py +0 -316
- {dataeval-0.72.2.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.2.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
@@ -2,16 +2,27 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = ["DiversityOutput", "diversity"]
|
4
4
|
|
5
|
+
import contextlib
|
5
6
|
from dataclasses import dataclass
|
6
7
|
from typing import Any, Literal, Mapping
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import ArrayLike, NDArray
|
10
11
|
|
11
|
-
from dataeval.metrics.bias.metadata import
|
12
|
+
from dataeval.metrics.bias.metadata import (
|
13
|
+
diversity_bar_plot,
|
14
|
+
entropy,
|
15
|
+
get_counts,
|
16
|
+
get_num_bins,
|
17
|
+
heatmap,
|
18
|
+
preprocess_metadata,
|
19
|
+
)
|
12
20
|
from dataeval.output import OutputMetadata, set_metadata
|
13
21
|
from dataeval.utils.shared import get_method
|
14
22
|
|
23
|
+
with contextlib.suppress(ImportError):
|
24
|
+
from matplotlib.figure import Figure
|
25
|
+
|
15
26
|
|
16
27
|
@dataclass(frozen=True)
|
17
28
|
class DiversityOutput(OutputMetadata):
|
@@ -32,36 +43,50 @@ class DiversityOutput(OutputMetadata):
|
|
32
43
|
|
33
44
|
diversity_index: NDArray[np.float64]
|
34
45
|
classwise: NDArray[np.float64]
|
35
|
-
|
36
|
-
class_list: NDArray[np.int64]
|
46
|
+
class_list: NDArray[Any]
|
37
47
|
metadata_names: list[str]
|
38
|
-
|
39
48
|
method: Literal["shannon", "simpson"]
|
40
49
|
|
41
|
-
def plot(
|
50
|
+
def plot(
|
51
|
+
self,
|
52
|
+
row_labels: list[Any] | NDArray[Any] | None = None,
|
53
|
+
col_labels: list[Any] | NDArray[Any] | None = None,
|
54
|
+
plot_classwise: bool = False,
|
55
|
+
) -> Figure:
|
42
56
|
"""
|
43
57
|
Plot a heatmap of diversity information
|
44
58
|
|
45
59
|
Parameters
|
46
60
|
----------
|
47
|
-
row_labels:
|
48
|
-
Array containing the labels for rows in the histogram
|
49
|
-
col_labels:
|
50
|
-
Array containing the labels for columns in the histogram
|
61
|
+
row_labels : ArrayLike | None, default None
|
62
|
+
List/Array containing the labels for rows in the histogram
|
63
|
+
col_labels : ArrayLike | None, default None
|
64
|
+
List/Array containing the labels for columns in the histogram
|
65
|
+
plot_classwise : bool, default False
|
66
|
+
Whether to plot per-class balance instead of global balance
|
51
67
|
"""
|
52
|
-
if
|
53
|
-
row_labels
|
54
|
-
|
55
|
-
col_labels
|
68
|
+
if plot_classwise:
|
69
|
+
if row_labels is None:
|
70
|
+
row_labels = self.class_list
|
71
|
+
if col_labels is None:
|
72
|
+
col_labels = self.metadata_names
|
73
|
+
|
74
|
+
fig = heatmap(
|
75
|
+
self.classwise,
|
76
|
+
row_labels,
|
77
|
+
col_labels,
|
78
|
+
xlabel="Factors",
|
79
|
+
ylabel="Class",
|
80
|
+
cbarlabel=f"Normalized {self.method.title()} Index",
|
81
|
+
)
|
56
82
|
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
)
|
83
|
+
else:
|
84
|
+
# Creating label array for heat map axes
|
85
|
+
heat_labels = np.concatenate((["class"], self.metadata_names))
|
86
|
+
|
87
|
+
fig = diversity_bar_plot(heat_labels, self.diversity_index)
|
88
|
+
|
89
|
+
return fig
|
65
90
|
|
66
91
|
|
67
92
|
def diversity_shannon(
|
@@ -237,19 +262,17 @@ def diversity(
|
|
237
262
|
numpy.histogram
|
238
263
|
"""
|
239
264
|
diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
|
240
|
-
data, names, is_categorical = preprocess_metadata(class_labels, metadata)
|
265
|
+
data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
|
241
266
|
diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
|
242
267
|
|
243
268
|
class_idx = names.index("class_label")
|
244
|
-
|
245
|
-
|
246
|
-
u_classes = np.unique(class_lbl)
|
269
|
+
u_classes = np.unique(data[:, class_idx])
|
247
270
|
num_factors = len(names)
|
248
271
|
diversity = np.empty((len(u_classes), num_factors))
|
249
272
|
diversity[:] = np.nan
|
250
273
|
for idx, cls in enumerate(u_classes):
|
251
|
-
subset_mask =
|
274
|
+
subset_mask = data[:, class_idx] == cls
|
252
275
|
diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
|
253
276
|
div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
|
254
277
|
|
255
|
-
return DiversityOutput(diversity_index, div_no_class,
|
278
|
+
return DiversityOutput(diversity_index, div_no_class, unique_labels, list(metadata.keys()), method)
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import contextlib
|
5
6
|
from typing import Any, Mapping
|
6
7
|
|
7
8
|
import numpy as np
|
@@ -10,6 +11,11 @@ from scipy.stats import entropy as sp_entropy
|
|
10
11
|
|
11
12
|
from dataeval.interop import to_numpy
|
12
13
|
|
14
|
+
with contextlib.suppress(ImportError):
|
15
|
+
from matplotlib.figure import Figure
|
16
|
+
|
17
|
+
CLASS_LABEL = "class_label"
|
18
|
+
|
13
19
|
|
14
20
|
def get_counts(
|
15
21
|
data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
|
@@ -147,14 +153,24 @@ def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]
|
|
147
153
|
|
148
154
|
def preprocess_metadata(
|
149
155
|
class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
|
150
|
-
) -> tuple[NDArray[Any], list[str], list[bool]]:
|
156
|
+
) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
|
157
|
+
# if class_labels is not numeric
|
158
|
+
class_array = to_numpy(class_labels)
|
159
|
+
if not np.issubdtype(class_array.dtype, np.number):
|
160
|
+
unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
|
161
|
+
else:
|
162
|
+
numerical_labels = np.asarray(class_array, dtype=int)
|
163
|
+
unique_classes = np.unique(class_array)
|
164
|
+
|
151
165
|
# convert class_labels and dict of lists to matrix of metadata values
|
152
|
-
preprocessed_metadata = {
|
166
|
+
preprocessed_metadata = {CLASS_LABEL: numerical_labels}
|
153
167
|
|
154
168
|
# map columns of dict that are not numeric (e.g. string) to numeric values
|
155
169
|
# that mutual information and diversity functions can accommodate. Each
|
156
170
|
# unique string receives a unique integer value.
|
157
171
|
for k, v in metadata.items():
|
172
|
+
if k == CLASS_LABEL:
|
173
|
+
k = "label_class"
|
158
174
|
# if not numeric
|
159
175
|
v = to_numpy(v)
|
160
176
|
if not np.issubdtype(v.dtype, np.number):
|
@@ -167,35 +183,34 @@ def preprocess_metadata(
|
|
167
183
|
names = list(preprocessed_metadata.keys())
|
168
184
|
is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
|
169
185
|
|
170
|
-
return data, names, is_categorical
|
186
|
+
return data, names, is_categorical, unique_classes
|
171
187
|
|
172
188
|
|
173
189
|
def heatmap(
|
174
190
|
data: NDArray[Any],
|
175
|
-
row_labels: NDArray[Any],
|
176
|
-
col_labels: NDArray[Any],
|
191
|
+
row_labels: list[str] | NDArray[Any],
|
192
|
+
col_labels: list[str] | NDArray[Any],
|
177
193
|
xlabel: str = "",
|
178
194
|
ylabel: str = "",
|
179
195
|
cbarlabel: str = "",
|
180
|
-
) ->
|
196
|
+
) -> Figure:
|
181
197
|
"""
|
182
198
|
Plots a formatted heatmap
|
183
199
|
|
184
200
|
Parameters
|
185
201
|
----------
|
186
|
-
data: NDArray
|
202
|
+
data : NDArray
|
187
203
|
Array containing numerical values for factors to plot
|
188
|
-
row_labels:
|
189
|
-
Array containing the labels for rows in the histogram
|
190
|
-
col_labels:
|
191
|
-
Array containing the labels for columns in the histogram
|
192
|
-
xlabel: str, default ""
|
204
|
+
row_labels : ArrayLike
|
205
|
+
List/Array containing the labels for rows in the histogram
|
206
|
+
col_labels : ArrayLike
|
207
|
+
List/Array containing the labels for columns in the histogram
|
208
|
+
xlabel : str, default ""
|
193
209
|
X-axis label
|
194
|
-
ylabel: str, default ""
|
210
|
+
ylabel : str, default ""
|
195
211
|
Y-axis label
|
196
|
-
cbarlabel: str, default ""
|
212
|
+
cbarlabel : str, default ""
|
197
213
|
Label for the colorbar
|
198
|
-
|
199
214
|
"""
|
200
215
|
import matplotlib
|
201
216
|
import matplotlib.pyplot as plt
|
@@ -252,7 +267,7 @@ def heatmap(
|
|
252
267
|
texts.append(text)
|
253
268
|
|
254
269
|
fig.tight_layout()
|
255
|
-
|
270
|
+
return fig
|
256
271
|
|
257
272
|
|
258
273
|
# Function to define how the text is displayed in the heatmap
|
@@ -273,3 +288,71 @@ def format_text(*args: str) -> str:
|
|
273
288
|
"""
|
274
289
|
x = args[0]
|
275
290
|
return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
|
291
|
+
|
292
|
+
|
293
|
+
def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
|
294
|
+
"""
|
295
|
+
Plots a formatted bar plot
|
296
|
+
|
297
|
+
Parameters
|
298
|
+
----------
|
299
|
+
labels : NDArray
|
300
|
+
Array containing the labels for each bar
|
301
|
+
bar_heights : NDArray
|
302
|
+
Array containing the values for each bar
|
303
|
+
"""
|
304
|
+
import matplotlib.pyplot as plt
|
305
|
+
|
306
|
+
fig, ax = plt.subplots(figsize=(10, 10))
|
307
|
+
|
308
|
+
ax.bar(labels, bar_heights)
|
309
|
+
ax.set_xlabel("Factors")
|
310
|
+
|
311
|
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
312
|
+
|
313
|
+
fig.tight_layout()
|
314
|
+
return fig
|
315
|
+
|
316
|
+
|
317
|
+
def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
|
318
|
+
"""
|
319
|
+
Creates a single plot of all of the provided images
|
320
|
+
|
321
|
+
Parameters
|
322
|
+
----------
|
323
|
+
images : NDArray
|
324
|
+
Array containing only the desired images to plot
|
325
|
+
"""
|
326
|
+
import matplotlib.pyplot as plt
|
327
|
+
|
328
|
+
num_images = min(num_images, len(images))
|
329
|
+
|
330
|
+
if images.ndim == 4:
|
331
|
+
images = np.moveaxis(images, 1, -1)
|
332
|
+
elif images.ndim == 3:
|
333
|
+
images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
|
334
|
+
else:
|
335
|
+
raise ValueError(
|
336
|
+
f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
|
337
|
+
)
|
338
|
+
|
339
|
+
rows = np.ceil(num_images / 3).astype(int)
|
340
|
+
fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
|
341
|
+
|
342
|
+
if rows == 1:
|
343
|
+
for j in range(3):
|
344
|
+
if j >= len(images):
|
345
|
+
continue
|
346
|
+
axs[j].imshow(images[j])
|
347
|
+
axs[j].axis("off")
|
348
|
+
else:
|
349
|
+
for i in range(rows):
|
350
|
+
for j in range(3):
|
351
|
+
i_j = i * 3 + j
|
352
|
+
if i_j >= len(images):
|
353
|
+
continue
|
354
|
+
axs[i, j].imshow(images[i_j])
|
355
|
+
axs[i, j].axis("off")
|
356
|
+
|
357
|
+
fig.tight_layout()
|
358
|
+
return fig
|
dataeval/metrics/bias/parity.py
CHANGED
@@ -11,6 +11,7 @@ from numpy.typing import ArrayLike, NDArray
|
|
11
11
|
from scipy.stats import chi2_contingency, chisquare
|
12
12
|
|
13
13
|
from dataeval.interop import to_numpy
|
14
|
+
from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
|
14
15
|
from dataeval.output import OutputMetadata, set_metadata
|
15
16
|
|
16
17
|
TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
@@ -27,10 +28,13 @@ class ParityOutput(Generic[TData], OutputMetadata):
|
|
27
28
|
chi-squared score(s) of the test
|
28
29
|
p_value : np.float64 | NDArray[np.float64]
|
29
30
|
p-value(s) of the test
|
31
|
+
metadata_names: list[str] | None
|
32
|
+
Names of each metadata factor
|
30
33
|
"""
|
31
34
|
|
32
35
|
score: TData
|
33
36
|
p_value: TData
|
37
|
+
metadata_names: list[str] | None
|
34
38
|
|
35
39
|
|
36
40
|
def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
|
@@ -66,7 +70,7 @@ def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name
|
|
66
70
|
|
67
71
|
|
68
72
|
def format_discretize_factors(
|
69
|
-
|
73
|
+
data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
|
70
74
|
) -> dict[str, NDArray[Any]]:
|
71
75
|
"""
|
72
76
|
Sets up the internal list of metadata factors.
|
@@ -89,30 +93,32 @@ def format_discretize_factors(
|
|
89
93
|
Each key is a metadata factor, whose value is the discrete per-image factor values.
|
90
94
|
"""
|
91
95
|
|
92
|
-
invalid_keys = set(continuous_factor_bincounts.keys()) - set(
|
96
|
+
invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
|
93
97
|
if invalid_keys:
|
94
98
|
raise KeyError(
|
95
99
|
f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
|
96
100
|
"keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
|
97
101
|
)
|
98
102
|
|
103
|
+
warn = []
|
99
104
|
metadata_factors = {}
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
105
|
+
for i, name in enumerate(names):
|
106
|
+
if name == CLASS_LABEL:
|
107
|
+
continue
|
108
|
+
if name in continuous_factor_bincounts:
|
109
|
+
metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
|
110
|
+
elif not is_categorical[i]:
|
111
|
+
warn.append(name)
|
112
|
+
metadata_factors[name] = data[:, i]
|
113
|
+
else:
|
114
|
+
metadata_factors[name] = data[:, i]
|
115
|
+
|
116
|
+
if warn:
|
117
|
+
warnings.warn(
|
118
|
+
f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
|
119
|
+
{warn}",
|
120
|
+
UserWarning,
|
121
|
+
)
|
116
122
|
|
117
123
|
return metadata_factors
|
118
124
|
|
@@ -247,7 +253,7 @@ def label_parity(
|
|
247
253
|
>>> expected_labels = np_random_gen.choice([0, 1, 2, 3, 4], (100))
|
248
254
|
>>> observed_labels = np_random_gen.choice([2, 3, 0, 4, 1], (100))
|
249
255
|
>>> label_parity(expected_labels, observed_labels)
|
250
|
-
ParityOutput(score=14.007374204742625, p_value=0.0072715574616218)
|
256
|
+
ParityOutput(score=14.007374204742625, p_value=0.0072715574616218, metadata_names=None)
|
251
257
|
"""
|
252
258
|
|
253
259
|
# Calculate
|
@@ -278,13 +284,13 @@ def label_parity(
|
|
278
284
|
)
|
279
285
|
|
280
286
|
cs, p = chisquare(f_obs=observed_dist, f_exp=expected_dist)
|
281
|
-
return ParityOutput(cs, p)
|
287
|
+
return ParityOutput(cs, p, None)
|
282
288
|
|
283
289
|
|
284
290
|
@set_metadata()
|
285
291
|
def parity(
|
286
292
|
class_labels: ArrayLike,
|
287
|
-
|
293
|
+
metadata: Mapping[str, ArrayLike],
|
288
294
|
continuous_factor_bincounts: Mapping[str, int] | None = None,
|
289
295
|
) -> ParityOutput[NDArray[np.float64]]:
|
290
296
|
"""
|
@@ -299,12 +305,12 @@ def parity(
|
|
299
305
|
----------
|
300
306
|
class_labels: ArrayLike
|
301
307
|
List of class labels for each image
|
302
|
-
|
308
|
+
metadata: Mapping[str, ArrayLike]
|
303
309
|
The dataset factors, which are per-image metadata attributes.
|
304
310
|
Each key of dataset_factors is a factor, whose value is the per-image factor values.
|
305
311
|
continuous_factor_bincounts : Mapping[str, int] | None, default None
|
306
312
|
A dictionary specifying the number of bins for discretizing the continuous factors.
|
307
|
-
The keys should correspond to the names of continuous factors in `
|
313
|
+
The keys should correspond to the names of continuous factors in `metadata`,
|
308
314
|
and the values should be the number of bins to use for discretization.
|
309
315
|
If not provided, no discretization is applied.
|
310
316
|
|
@@ -337,42 +343,44 @@ def parity(
|
|
337
343
|
Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
|
338
344
|
|
339
345
|
>>> labels = np_random_gen.choice([0, 1, 2], (100))
|
340
|
-
>>>
|
346
|
+
>>> metadata = {
|
341
347
|
... "age": np_random_gen.choice([25, 30, 35, 45], (100)),
|
342
348
|
... "income": np_random_gen.choice([50000, 65000, 80000], (100)),
|
343
349
|
... "gender": np_random_gen.choice(["M", "F"], (100)),
|
344
350
|
... }
|
345
351
|
>>> continuous_factor_bincounts = {"age": 4, "income": 3}
|
346
|
-
>>> parity(labels,
|
347
|
-
ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]))
|
348
|
-
"""
|
352
|
+
>>> parity(labels, metadata, continuous_factor_bincounts)
|
353
|
+
ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
|
354
|
+
""" # noqa: E501
|
349
355
|
if len(np.shape(class_labels)) > 1:
|
350
356
|
raise ValueError(
|
351
357
|
f"Got class labels with {len(np.shape(class_labels))}-dimensional",
|
352
358
|
f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
|
353
359
|
)
|
354
360
|
|
355
|
-
|
361
|
+
data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
|
356
362
|
continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
|
357
363
|
|
358
|
-
|
359
|
-
|
364
|
+
factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
|
365
|
+
|
366
|
+
# unique class labels
|
367
|
+
class_idx = names.index(CLASS_LABEL)
|
368
|
+
u_cls = np.unique(data[:, class_idx])
|
360
369
|
|
361
370
|
chi_scores = np.zeros(len(factors))
|
362
371
|
p_values = np.zeros(len(factors))
|
363
|
-
n_cls = len(np.unique(labels))
|
364
372
|
not_enough_data = {}
|
365
373
|
for i, (current_factor_name, factor_values) in enumerate(factors.items()):
|
366
374
|
unique_factor_values = np.unique(factor_values)
|
367
|
-
contingency_matrix = np.zeros((len(unique_factor_values),
|
375
|
+
contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
|
368
376
|
# Builds a contingency matrix where entry at index (r,c) represents
|
369
377
|
# the frequency of current_factor_name achieving value unique_factor_values[r]
|
370
378
|
# at a data point with class c.
|
371
379
|
|
372
380
|
# TODO: Vectorize this nested for loop
|
373
381
|
for fi, factor_value in enumerate(unique_factor_values):
|
374
|
-
for label in
|
375
|
-
with_both = np.bitwise_and((
|
382
|
+
for label in u_cls:
|
383
|
+
with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
|
376
384
|
contingency_matrix[fi, label] = np.sum(with_both)
|
377
385
|
if 0 < contingency_matrix[fi, label] < 5:
|
378
386
|
if current_factor_name not in not_enough_data:
|
@@ -414,4 +422,4 @@ def parity(
|
|
414
422
|
UserWarning,
|
415
423
|
)
|
416
424
|
|
417
|
-
return ParityOutput(chi_scores, p_values)
|
425
|
+
return ParityOutput(chi_scores, p_values, list(metadata.keys()))
|
dataeval/utils/__init__.py
CHANGED
@@ -5,9 +5,10 @@ metrics. Currently DataEval supports both :term:`TensorFlow` and PyTorch backend
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
|
8
|
+
from dataeval.utils.metadata import merge_metadata
|
8
9
|
from dataeval.utils.split_dataset import split_dataset
|
9
10
|
|
10
|
-
__all__ = ["split_dataset"]
|
11
|
+
__all__ = ["split_dataset", "merge_metadata"]
|
11
12
|
|
12
13
|
if _IS_TORCH_AVAILABLE: # pragma: no cover
|
13
14
|
from dataeval.utils import torch
|
dataeval/utils/lazy.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from functools import cached_property
|
4
|
+
from importlib import import_module
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
|
8
|
+
class LazyModule:
|
9
|
+
def __init__(self, name: str) -> None:
|
10
|
+
self._name = name
|
11
|
+
|
12
|
+
def __getattr__(self, key: str) -> Any:
|
13
|
+
return getattr(self._module, key)
|
14
|
+
|
15
|
+
@cached_property
|
16
|
+
def _module(self):
|
17
|
+
return import_module(self._name)
|
18
|
+
|
19
|
+
|
20
|
+
LAZY_MODULES: dict[str, LazyModule] = {}
|
21
|
+
|
22
|
+
|
23
|
+
def lazyload(name: str) -> LazyModule:
|
24
|
+
if name not in LAZY_MODULES:
|
25
|
+
LAZY_MODULES[name] = LazyModule(name)
|
26
|
+
return LAZY_MODULES[name]
|