dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +9 -9
- dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
- dataeval/detectors/ood/base.py +63 -113
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/metadata_ks_compare.py +52 -14
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +73 -70
- dataeval/metrics/bias/coverage.py +4 -4
- dataeval/metrics/bias/diversity.py +67 -136
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +51 -161
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/aegmm.py +0 -66
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -97
- dataeval/detectors/ood/vaegmm.py +0 -75
- dataeval/metrics/bias/metadata.py +0 -440
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -123
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.73.1.dist-info/RECORD +0 -73
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/metadata.py
DELETED
@@ -1,440 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-import contextlib
-from typing import Any, Mapping
-
-import numpy as np
-from numpy.typing import ArrayLike, NDArray
-from scipy.stats import entropy as sp_entropy
-
-from dataeval.interop import to_numpy
-
-with contextlib.suppress(ImportError):
-    from matplotlib.figure import Figure
-
-CLASS_LABEL = "class_label"
-
-
-def get_counts(
-    data: NDArray[Any],
-    names: list[str],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-    subset_mask: NDArray[np.bool_] | None = None,
-    hist_cache: dict[str, NDArray[np.intp]] | None = None,
-) -> dict[str, NDArray[np.intp]]:
-    """
-    Initialize dictionary of histogram counts --- treat categorical values
-    as histogram bins.
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing numerical values for metadata factors
-    names : list[str]
-        Names of metadata factors -- keys of the metadata dictionary
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        The factors in names that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in names.
-        Names of metadata factors -- keys of the metadata dictionary
-    subset_mask : NDArray[np.bool_] or None, default None
-        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
-    hist_cache : dict[str, NDArray[np.intp]] or None, default None
-        Optional cache to store histogram counts
-
-    Returns
-    -------
-    dict[str, NDArray[np.intp]]
-        histogram counts per metadata factor in `factors`. Each
-        factor will have a different number of bins. Counts get reused
-        across metrics, so hist_counts are cached but only if computed
-        globally, i.e. without masked samples.
-    """
-
-    hist_counts = {}
-
-    mask = subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=np.bool_)
-
-    for cdx, fn in enumerate(names):
-        if hist_cache is not None and fn in hist_cache:
-            cnts = hist_cache[fn]
-        else:
-            hist_edges = np.array([-np.inf, np.inf])
-            cnts = np.array([len(data[:, cdx].squeeze())])
-            # linter doesn't like double indexing
-            col_data = np.array(data[mask, cdx].squeeze(), dtype=np.float64)
-
-            if continuous_factor_bincounts and fn in continuous_factor_bincounts:
-                num_bins = continuous_factor_bincounts[fn]
-                _, hist_edges = np.histogram(data[:, cdx].squeeze(), bins=num_bins, density=True)
-                hist_edges[-1] = np.inf
-                hist_edges[0] = -np.inf
-                disc_col_data = np.digitize(col_data, np.array(hist_edges))
-                _, cnts = np.unique(disc_col_data, return_counts=True)
-            else:
-                _, cnts = np.unique(col_data, return_counts=True)
-
-            if hist_cache is not None:
-                hist_cache[fn] = cnts
-
-        hist_counts[fn] = cnts
-
-    return hist_counts
-
-
-def entropy(
-    data: NDArray[Any],
-    names: list[str],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-    normalized: bool = False,
-    subset_mask: NDArray[np.bool_] | None = None,
-    hist_cache: dict[str, NDArray[np.intp]] | None = None,
-) -> NDArray[np.float64]:
-    """
-    Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
-    ClasswiseBalance, and Classwise Diversity.
-
-    Compute entropy for discrete/categorical variables and for continuous variables through standard
-    histogram binning.
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing numerical values for metadata factors
-    names : list[str]
-        Names of metadata factors -- keys of the metadata dictionary
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        The factors in names that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in names.
-    normalized : bool, default False
-        Flag that determines whether or not to normalize entropy by log(num_bins)
-    subset_mask : NDArray[np.bool_] or None, default None
-        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
-    hist_cache : dict[str, NDArray[np.intp]] or None, default None
-        Optional cache to store histogram counts
-
-    Notes
-    -----
-    For continuous variables, histogram bins are chosen automatically. See
-    numpy.histogram for details.
-
-    Returns
-    -------
-    NDArray[np.float64]
-        Entropy estimate per column of X
-
-    See Also
-    --------
-    numpy.histogram
-    scipy.stats.entropy
-    """
-
-    num_factors = len(names)
-    hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
-
-    ev_index = np.empty(num_factors)
-    for col, cnts in enumerate(hist_counts.values()):
-        # entropy in nats, normalizes counts
-        ev_index[col] = sp_entropy(cnts)
-        if normalized:
-            cnt_len = np.size(cnts, 0)
-            if cnt_len == 1:
-                # log(0)
-                ev_index[col] = 0
-            else:
-                ev_index[col] /= np.log(cnt_len)
-    return ev_index
-
-
-def get_num_bins(
-    data: NDArray[Any],
-    names: list[str],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-    subset_mask: NDArray[np.bool_] | None = None,
-    hist_cache: dict[str, NDArray[np.intp]] | None = None,
-) -> NDArray[np.float64]:
-    """
-    Number of bins or unique values for each metadata factor, used to
-    normalize entropy/diversity.
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing numerical values for metadata factors
-    names : list[str]
-        Names of metadata factors -- keys of the metadata dictionary
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        The factors in names that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in names.
-    subset_mask : NDArray[np.bool_] or None, default None
-        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
-    hist_cache : dict[str, NDArray[np.intp]] or None, default None
-        Optional cache to store histogram counts
-
-    Returns
-    -------
-    NDArray[np.float64]
-        Number of bins used in the discretization for each value in names.
-    """
-    # likely cached
-    hist_counts = get_counts(data, names, continuous_factor_bincounts, subset_mask, hist_cache)
-    num_bins = np.empty(len(hist_counts))
-    for idx, cnts in enumerate(hist_counts.values()):
-        num_bins[idx] = np.size(cnts, 0)
-
-    return num_bins
-
-
-def infer_categorical(arr: NDArray[np.float64], threshold: float = 0.2) -> NDArray[np.bool_]:
-    """
-    Compute fraction of feature values that are unique --- intended to be used
-    for inferring whether variables are categorical.
-    """
-    if arr.ndim == 1:
-        arr = np.expand_dims(arr, axis=1)
-    num_samples = arr.shape[0]
-    pct_unique = np.empty(arr.shape[1])
-    for col in range(arr.shape[1]):  # type: ignore
-        uvals = np.unique(arr[:, col], axis=0)
-        pct_unique[col] = len(uvals) / num_samples
-    return pct_unique < threshold
-
-
-def preprocess_metadata(
-    class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
-) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
-    """
-    Formats metadata by organizing factor names, converting labels to numeric values,
-    adds class labels to the dataset structure, and marks which factors are categorical.
-    """
-    # if class_labels is not numeric
-    class_array = to_numpy(class_labels)
-    if not np.issubdtype(class_array.dtype, np.integer):
-        unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
-    else:
-        numerical_labels = np.asarray(class_array, dtype=np.intp)
-        unique_classes = np.unique(class_array)
-
-    # convert class_labels and dict of lists to matrix of metadata values
-    preprocessed_metadata = {CLASS_LABEL: numerical_labels}
-
-    # map columns of dict that are not numeric (e.g. string) to numeric values
-    # that mutual information and diversity functions can accommodate. Each
-    # unique string receives a unique integer value.
-    for k, v in metadata.items():
-        if k == CLASS_LABEL:
-            continue
-        # if not numeric
-        v = to_numpy(v)
-        if not np.issubdtype(v.dtype, np.number):
-            _, mapped_vals = np.unique(v, return_inverse=True)
-            preprocessed_metadata[k] = mapped_vals
-        else:
-            preprocessed_metadata[k] = v
-
-    data = np.stack(list(preprocessed_metadata.values()), axis=-1)
-    names = list(preprocessed_metadata.keys())
-    is_categorical = [
-        var == CLASS_LABEL or infer_categorical(preprocessed_metadata[var].astype(np.float64), cat_thresh)[0]
-        for var in names
-    ]
-
-    return data, names, is_categorical, unique_classes
-
-
-def heatmap(
-    data: ArrayLike,
-    row_labels: list[str] | ArrayLike,
-    col_labels: list[str] | ArrayLike,
-    xlabel: str = "",
-    ylabel: str = "",
-    cbarlabel: str = "",
-) -> Figure:
-    """
-    Plots a formatted heatmap
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing numerical values for factors to plot
-    row_labels : ArrayLike
-        List/Array containing the labels for rows in the histogram
-    col_labels : ArrayLike
-        List/Array containing the labels for columns in the histogram
-    xlabel : str, default ""
-        X-axis label
-    ylabel : str, default ""
-        Y-axis label
-    cbarlabel : str, default ""
-        Label for the colorbar
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        Formatted heatmap
-    """
-    import matplotlib.pyplot as plt
-    from matplotlib.ticker import FuncFormatter
-
-    np_data = to_numpy(data)
-    rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
-    cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
-
-    fig, ax = plt.subplots(figsize=(10, 10))
-
-    # Plot the heatmap
-    im = ax.imshow(np_data, vmin=0, vmax=1.0)
-
-    # Create colorbar
-    cbar = fig.colorbar(im, shrink=0.5)
-    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
-    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
-    cbar.set_label(cbarlabel, loc="center")
-
-    # Show all ticks and label them with the respective list entries.
-    ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
-    ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
-
-    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
-    # Rotate the tick labels and set their alignment.
-    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
-
-    # Turn spines off and create white grid.
-    ax.spines[:].set_visible(False)
-
-    ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
-    ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
-    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
-    ax.tick_params(which="minor", bottom=False, left=False)
-
-    if xlabel:
-        ax.set_xlabel(xlabel)
-    if ylabel:
-        ax.set_ylabel(ylabel)
-
-    valfmt = FuncFormatter(format_text)
-
-    # Normalize the threshold to the images color range.
-    threshold = im.norm(1.0) / 2.0
-
-    # Set default alignment to center, but allow it to be
-    # overwritten by textkw.
-    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
-
-    # Loop over the data and create a `Text` for each "pixel".
-    # Change the text's color depending on the data.
-    textcolors = ("white", "black")
-    texts = []
-    for i in range(np_data.shape[0]):
-        for j in range(np_data.shape[1]):
-            kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
-            text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw)  # type: ignore
-            texts.append(text)
-
-    fig.tight_layout()
-    return fig
-
-
-# Function to define how the text is displayed in the heatmap
-def format_text(*args: str) -> str:
-    """
-    Helper function to format text for heatmap()
-
-    Parameters
-    ----------
-    *args : tuple[str, str]
-        Text to be formatted. Second element is ignored, but is a
-        mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
-
-    Returns
-    -------
-    str
-        Formatted text
-    """
-    x = args[0]
-    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
-
-
-def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
-    """
-    Plots a formatted bar plot
-
-    Parameters
-    ----------
-    labels : NDArray
-        Array containing the labels for each bar
-    bar_heights : NDArray
-        Array containing the values for each bar
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        Bar plot figure
-    """
-    import matplotlib.pyplot as plt
-
-    fig, ax = plt.subplots(figsize=(10, 10))
-
-    ax.bar(labels, bar_heights)
-    ax.set_xlabel("Factors")
-
-    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
-
-    fig.tight_layout()
-    return fig
-
-
-def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
-    """
-    Creates a single plot of all of the provided images
-
-    Parameters
-    ----------
-    images : NDArray
-        Array containing only the desired images to plot
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        Plot of all provided images
-    """
-    import matplotlib.pyplot as plt
-
-    num_images = min(num_images, len(images))
-
-    if images.ndim == 4:
-        images = np.moveaxis(images, 1, -1)
-    elif images.ndim == 3:
-        images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
-    else:
-        raise ValueError(
-            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
-        )
-
-    rows = int(np.ceil(num_images / 3))
-    fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
-
-    if rows == 1:
-        for j in range(3):
-            if j >= len(images):
-                continue
-            axs[j].imshow(images[j])
-            axs[j].axis("off")
-    else:
-        for i in range(rows):
-            for j in range(3):
-                i_j = i * 3 + j
-                if i_j >= len(images):
-                    continue
-                axs[i, j].imshow(images[i_j])
-                axs[i, j].axis("off")
-
-    fig.tight_layout()
-    return fig
dataeval/utils/lazy.py
DELETED
@@ -1,26 +0,0 @@
-from __future__ import annotations
-
-from functools import cached_property
-from importlib import import_module
-from typing import Any
-
-
-class LazyModule:
-    def __init__(self, name: str) -> None:
-        self._name = name
-
-    def __getattr__(self, key: str) -> Any:
-        return getattr(self._module, key)
-
-    @cached_property
-    def _module(self):
-        return import_module(self._name)
-
-
-LAZY_MODULES: dict[str, LazyModule] = {}
-
-
-def lazyload(name: str) -> LazyModule:
-    if name not in LAZY_MODULES:
-        LAZY_MODULES[name] = LazyModule(name)
-    return LAZY_MODULES[name]
dataeval/utils/tensorflow/__init__.py
DELETED
@@ -1,19 +0,0 @@
-"""
-TensorFlow models are used in :term:`out of distribution<Out-of-distribution (OOD)>` detectors in the
-:mod:`dataeval.detectors.ood` module.
-
-DataEval provides basic default models through the utility :func:`dataeval.utils.tensorflow.create_model`.
-"""
-
-from dataeval import _IS_TENSORFLOW_AVAILABLE
-
-__all__ = []
-
-
-if _IS_TENSORFLOW_AVAILABLE:
-    import dataeval.utils.tensorflow.loss as loss
-    from dataeval.utils.tensorflow._internal.utils import create_model
-
-    __all__ = ["create_model", "loss"]
-
-del _IS_TENSORFLOW_AVAILABLE
dataeval/utils/tensorflow/_internal/gmm.py
DELETED
@@ -1,123 +0,0 @@
-"""
-Source code derived from Alibi-Detect 0.11.4
-https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
-
-Original code Copyright (c) 2023 Seldon Technologies Ltd
-Licensed under Apache Software License (Apache 2.0)
-"""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, NamedTuple
-
-import numpy as np
-
-from dataeval.utils.lazy import lazyload
-
-if TYPE_CHECKING:
-    import tensorflow as tf
-else:
-    tf = lazyload("tensorflow")
-
-
-class GaussianMixtureModelParams(NamedTuple):
-    """
-    phi : tf.Tensor
-        Mixture component distribution weights.
-    mu : tf.Tensor
-        Mixture means.
-    cov : tf.Tensor
-        Mixture covariance.
-    L : tf.Tensor
-        Cholesky decomposition of `cov`.
-    log_det_cov : tf.Tensor
-        Log of the determinant of `cov`.
-    """
-
-    phi: tf.Tensor
-    mu: tf.Tensor
-    cov: tf.Tensor
-    L: tf.Tensor
-    log_det_cov: tf.Tensor
-
-
-def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams:
-    """
-    Compute parameters of Gaussian Mixture Model.
-
-    Parameters
-    ----------
-    z : tf.Tensor
-        Observations.
-    gamma : tf.Tensor
-        Mixture probabilities to derive mixture distribution weights from.
-
-    Returns
-    -------
-    GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
-        The parameters used to calculate energy.
-    """
-    # compute gmm parameters phi, mu and cov
-    N = gamma.shape[0]  # nb of samples in batch
-    sum_gamma = tf.reduce_sum(gamma, 0)  # K
-    phi = sum_gamma / N  # K
-    mu = tf.reduce_sum(tf.expand_dims(gamma, -1) * tf.expand_dims(z, 1), 0) / tf.expand_dims(
-        sum_gamma, -1
-    )  # K x D (D = latent_dim)
-    z_mu = tf.expand_dims(z, 1) - tf.expand_dims(mu, 0)  # N x K x D
-    z_mu_outer = tf.expand_dims(z_mu, -1) * tf.expand_dims(z_mu, -2)  # N x K x D x D
-    cov = tf.reduce_sum(tf.expand_dims(tf.expand_dims(gamma, -1), -1) * z_mu_outer, 0) / tf.expand_dims(
-        tf.expand_dims(sum_gamma, -1), -1
-    )  # K x D x D
-
-    # cholesky decomposition of covariance and determinant derivation
-    D = tf.shape(cov)[1]  # type: ignore
-    eps = 1e-6
-    L = tf.linalg.cholesky(cov + tf.eye(D) * eps)  # K x D x D
-    log_det_cov = 2.0 * tf.reduce_sum(tf.math.log(tf.linalg.diag_part(L)), 1)  # K
-
-    return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
-
-
-def gmm_energy(
-    z: tf.Tensor,
-    params: GaussianMixtureModelParams,
-    return_mean: bool = True,
-) -> tuple[tf.Tensor, tf.Tensor]:
-    """
-    Compute sample energy from Gaussian Mixture Model.
-
-    Parameters
-    ----------
-    params : GaussianMixtureModelParams
-        The gaussian mixture model parameters.
-    return_mean : bool, default True
-        Take mean across all sample energies in a batch.
-
-    Returns
-    -------
-    sample_energy
-        The sample energy of the GMM.
-    cov_diag
-        The inverse sum of the diagonal components of the covariance matrix.
-    """
-    D = tf.shape(params.cov)[1]  # type: ignore
-    z_mu = tf.expand_dims(z, 1) - tf.expand_dims(params.mu, 0)  # N x K x D
-    z_mu_T = tf.transpose(z_mu, perm=[1, 2, 0])  # K x D x N
-    v = tf.linalg.triangular_solve(params.L, z_mu_T, lower=True)  # K x D x D
-
-    # rewrite sample energy in logsumexp format for numerical stability
-    logits = tf.math.log(tf.expand_dims(params.phi, -1)) - 0.5 * (
-        tf.reduce_sum(tf.square(v), 1)
-        + tf.cast(D, tf.float32) * tf.math.log(2.0 * np.pi)  # type: ignore py38
-        + tf.expand_dims(params.log_det_cov, -1)
-    )  # K x N
-    sample_energy = -tf.reduce_logsumexp(logits, axis=0)  # N
-
-    if return_mean:
-        sample_energy = tf.reduce_mean(sample_energy)
-
-    # inverse sum of variances
-    cov_diag = tf.reduce_sum(tf.divide(1, tf.linalg.diag_part(params.cov)))
-
-    return sample_energy, cov_diag