dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/_internal/detectors/clusterer.py +47 -34
- dataeval/_internal/detectors/drift/base.py +53 -35
- dataeval/_internal/detectors/drift/cvm.py +5 -4
- dataeval/_internal/detectors/drift/ks.py +7 -6
- dataeval/_internal/detectors/drift/mmd.py +39 -19
- dataeval/_internal/detectors/drift/torch.py +6 -5
- dataeval/_internal/detectors/drift/uncertainty.py +7 -8
- dataeval/_internal/detectors/duplicates.py +57 -30
- dataeval/_internal/detectors/linter.py +40 -24
- dataeval/_internal/detectors/ood/ae.py +2 -1
- dataeval/_internal/detectors/ood/aegmm.py +2 -1
- dataeval/_internal/detectors/ood/base.py +37 -15
- dataeval/_internal/detectors/ood/llr.py +9 -8
- dataeval/_internal/detectors/ood/vae.py +2 -1
- dataeval/_internal/detectors/ood/vaegmm.py +2 -1
- dataeval/_internal/flags.py +42 -21
- dataeval/_internal/interop.py +3 -12
- dataeval/_internal/metrics/balance.py +188 -0
- dataeval/_internal/metrics/ber.py +123 -48
- dataeval/_internal/metrics/coverage.py +90 -74
- dataeval/_internal/metrics/divergence.py +101 -67
- dataeval/_internal/metrics/diversity.py +211 -0
- dataeval/_internal/metrics/parity.py +287 -155
- dataeval/_internal/metrics/stats.py +198 -317
- dataeval/_internal/metrics/uap.py +40 -29
- dataeval/_internal/metrics/utils.py +430 -0
- dataeval/_internal/models/tensorflow/losses.py +3 -3
- dataeval/_internal/models/tensorflow/trainer.py +3 -2
- dataeval/_internal/models/tensorflow/utils.py +4 -3
- dataeval/_internal/output.py +82 -0
- dataeval/_internal/utils.py +64 -0
- dataeval/_internal/workflows/sufficiency.py +96 -107
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +26 -7
- dataeval/utils/__init__.py +9 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
- dataeval-0.65.0.dist-info/RECORD +60 -0
- dataeval/_internal/functional/__init__.py +0 -0
- dataeval/_internal/functional/ber.py +0 -63
- dataeval/_internal/functional/coverage.py +0 -75
- dataeval/_internal/functional/divergence.py +0 -16
- dataeval/_internal/functional/hash.py +0 -79
- dataeval/_internal/functional/metadata.py +0 -136
- dataeval/_internal/functional/metadataparity.py +0 -190
- dataeval/_internal/functional/uap.py +0 -6
- dataeval/_internal/functional/utils.py +0 -158
- dataeval/_internal/maite/__init__.py +0 -0
- dataeval/_internal/maite/utils.py +0 -30
- dataeval/_internal/metrics/base.py +0 -92
- dataeval/_internal/metrics/metadata.py +0 -610
- dataeval/_internal/metrics/metadataparity.py +0 -67
- dataeval-0.63.0.dist-info/RECORD +0 -68
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,430 @@
|
|
1
|
+
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import xxhash as xxh
|
5
|
+
from numpy.typing import NDArray
|
6
|
+
from PIL import Image
|
7
|
+
from scipy.fftpack import dct
|
8
|
+
from scipy.signal import convolve2d
|
9
|
+
from scipy.sparse import csr_matrix
|
10
|
+
from scipy.sparse.csgraph import minimum_spanning_tree as mst
|
11
|
+
from scipy.spatial.distance import pdist, squareform
|
12
|
+
from scipy.stats import entropy as sp_entropy
|
13
|
+
from sklearn.neighbors import NearestNeighbors
|
14
|
+
|
15
|
+
# Small positive offset added to pairwise distances in minimum_spanning_tree so
# zero-distance pairs are still treated as graph edges.
EPSILON = 1e-5

# 3x3 edge-detection kernel applied by edge_filter.
EDGE_KERNEL = np.array(
    [
        [-1, -1, -1],
        [-1, 8, -1],
        [-1, -1, -1],
    ],
    dtype=np.int8,
)

# Candidate bit depths (ascending) tried, in order, by get_bitdepth.
BIT_DEPTH = (1, 8, 12, 16, 32)

# Side length of the low-frequency DCT block kept by pchash (HASH_SIZE**2 bits).
HASH_SIZE = 8

# Cap on the resample multiplier in pchash: side <= HASH_SIZE * MAX_FACTOR.
MAX_FACTOR = 4
|
20
|
+
|
21
|
+
|
22
|
+
def get_method(method_map: Dict[str, Callable], method: str) -> Callable:
    """
    Look up a callable by name in a method registry.

    Parameters
    ----------
    method_map : Dict[str, Callable]
        Mapping of method names to callables.
    method : str
        Name of the method to retrieve.

    Returns
    -------
    Callable
        The callable registered under `method`.

    Raises
    ------
    ValueError
        If `method` is not a key of `method_map`.
    """
    if method in method_map:
        return method_map[method]
    raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
|
26
|
+
|
27
|
+
|
28
|
+
def get_counts(
    data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
) -> Tuple[Dict, Dict]:
    """
    Initialize dictionary of histogram counts --- treat categorical values
    as histogram bins.

    Parameters
    ----------
    data : NDArray
        Metadata array indexed as (sample, factor); column ``i`` holds the
        values of factor ``names[i]``.
    names : List[str]
        Name of each metadata factor, used as the key of the returned dicts.
    is_categorical : List[bool]
        Per-factor flag. True -> every unique value is its own bin;
        False -> bins are chosen by ``np.histogram(..., bins="auto")``.
    subset_mask: Optional[NDArray[np.bool_]]
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

    Returns
    -------
    counts: Dict
        Histogram values per metadata factor in `names`. Each factor may have
        a different number of bins. Categorical factors hold raw counts;
        continuous factors hold densities (``density=True``).
    bins: Dict
        Per-factor bins: the unique values for categorical factors and the
        bin edges for continuous factors.

    Notes
    -----
    Nothing is cached; the histograms are recomputed on every call.
    """
    hist_counts, hist_bins = {}, {}
    # np.where needed to satisfy linter
    mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))

    for cdx, fn in enumerate(names):
        # linter doesn't like double indexing; squeeze() removes the
        # length-1 axes this indexing form leaves on the selection
        col_data = data[mask, cdx].squeeze()
        if is_categorical[cdx]:
            # if discrete, use unique values as bins
            bins, cnts = np.unique(col_data, return_counts=True)
        else:
            cnts, bins = np.histogram(col_data, bins="auto", density=True)

        hist_counts[fn] = cnts
        hist_bins[fn] = bins

    return hist_counts, hist_bins
|
67
|
+
|
68
|
+
|
69
|
+
def entropy(
    data: NDArray,
    names: List[str],
    is_categorical: List[bool],
    normalized: bool = False,
    subset_mask: Optional[NDArray[np.bool_]] = None,
) -> NDArray[np.float64]:
    """
    Meant for use with Bias metrics, Balance, Diversity, ClasswiseBalance,
    and Classwise Diversity.

    Compute entropy for discrete/categorical variables and for continuous variables through standard
    histogram binning.

    Parameters
    ----------
    data : NDArray
        Metadata array indexed as (sample, factor).
    names : List[str]
        Name of each metadata factor (column of `data`).
    is_categorical : List[bool]
        Per-factor flag selecting unique-value binning (True) or automatic
        histogram binning (False).
    normalized: bool
        Flag that determines whether or not to normalize entropy by log(num_bins)
    subset_mask: Optional[NDArray[np.bool_]]
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

    Notes
    -----
    For continuous variables, histogram bins are chosen automatically. See
    numpy.histogram for details. Entropy is measured in nats. With
    `normalized`, a factor with at most one bin is reported as 0 because
    log(num_bins) is 0 (or undefined) there.

    Returns
    -------
    ent: NDArray[np.float64]
        Entropy estimate per column of `data`

    See Also
    --------
    numpy.histogram
    scipy.stats.entropy
    """
    num_factors = len(names)
    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)

    ev_index = np.empty(num_factors)
    for col, cnts in enumerate(hist_counts.values()):
        # entropy in nats; sp_entropy normalizes the counts to probabilities
        ev_index[col] = sp_entropy(cnts)
        if normalized:
            num_bins = len(cnts)
            if num_bins <= 1:
                # log(1) == 0 and log(0) == -inf: the normalizer is unusable,
                # so report zero entropy instead of dividing by it.
                ev_index[col] = 0
            else:
                ev_index[col] /= np.log(num_bins)
    return ev_index
|
120
|
+
|
121
|
+
|
122
|
+
def get_num_bins(
    data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
) -> NDArray[np.float64]:
    """
    Number of bins or unique values for each metadata factor, used to
    normalize entropy/diversity.

    Parameters
    ----------
    data : NDArray
        Metadata array indexed as (sample, factor).
    names : List[str]
        Name of each metadata factor (column of `data`).
    is_categorical : List[bool]
        Per-factor flag selecting unique-value binning (True) or automatic
        histogram binning (False).
    subset_mask: Optional[NDArray[np.bool_]]
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

    Returns
    -------
    NDArray[np.float64]
        Number of bins per factor, in the order produced by `get_counts`.
    """
    # get_counts does not cache: the histograms are recomputed here.
    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
    return np.array([len(cnts) for cnts in hist_counts.values()], dtype=np.float64)
|
145
|
+
|
146
|
+
|
147
|
+
def infer_categorical(X: NDArray, threshold: float = 0.5) -> NDArray:
    """
    Infer which features are categorical from their fraction of unique values.

    A feature (column) is flagged as categorical when its number of distinct
    values divided by the number of samples is strictly below `threshold`.

    Parameters
    ----------
    X : NDArray
        Feature array indexed as (sample, feature). A 1-D input is treated as
        a single feature.
    threshold : float, default 0.5
        Fraction of unique values below which a feature counts as categorical.

    Returns
    -------
    NDArray[np.bool_]
        One flag per feature; True where the feature is inferred categorical.
    """
    if X.ndim == 1:
        X = np.expand_dims(X, axis=1)
    num_samples = X.shape[0]
    # axis=0 so that, for >2-D inputs, unique sub-arrays (not scalars) are counted
    pct_unique = np.array(
        [len(np.unique(X[:, col], axis=0)) / num_samples for col in range(X.shape[1])],
        dtype=np.float64,
    )
    return pct_unique < threshold
|
160
|
+
|
161
|
+
|
162
|
+
def preprocess_metadata(
    class_labels: Sequence[int], metadata: List[Dict], cat_thresh: float = 0.2
) -> Tuple[NDArray, List[str], List[bool]]:
    """
    Convert class labels and per-sample metadata dicts into one numeric array.

    Parameters
    ----------
    class_labels : Sequence[int]
        Class label of each sample; stored as the first factor, "class_label".
    metadata : List[Dict]
        One dict per sample. The keys of ``metadata[0]`` define the factors and
        every dict must contain those keys.
    cat_thresh : float, default 0.2
        Threshold passed to `infer_categorical`: a factor is categorical when
        its fraction of unique values is below this value.

    Returns
    -------
    data : NDArray
        Factor arrays stacked along the last axis; for scalar metadata values
        this is (num_samples, num_factors).
    names : List[str]
        Factor names in column order, starting with "class_label".
    is_categorical : List[bool]
        Whether each factor was inferred to be categorical.

    Notes
    -----
    Non-numeric factors (e.g. strings) are replaced by integer codes, one per
    unique value, so downstream metrics can operate on them.
    """
    # convert class_labels and list of metadata dicts to dict of ndarrays
    # NOTE(review): a "class_label" key inside `metadata` silently overrides
    # `class_labels` here (later dict entries win) — confirm this is intended.
    metadata_dict: Dict[str, NDArray] = {
        "class_label": np.asarray(class_labels, dtype=int),
        **{k: np.array([d[k] for d in metadata]) for k in metadata[0]},
    }

    # map columns of dict that are not numeric (e.g. string) to numeric values
    # that mutual information and diversity functions can accommodate. Each
    # unique string receives a unique integer value.
    for k, v in metadata_dict.items():
        if not np.issubdtype(v.dtype, np.number):
            _, mapped_vals = np.unique(v, return_inverse=True)
            metadata_dict[k] = mapped_vals

    data = np.stack(list(metadata_dict.values()), axis=-1)
    names = list(metadata_dict.keys())
    # bool() converts np.bool_ so the result matches the List[bool] annotation
    is_categorical = [bool(infer_categorical(metadata_dict[var], cat_thresh)[0]) for var in names]

    return data, names, is_categorical
|
185
|
+
|
186
|
+
|
187
|
+
def flatten(X: NDArray):
    """
    Collapse every axis after the first into a single feature axis.

    Parameters
    ----------
    X : NDArray, shape - (N, ... )
        Input array; the first axis indexes samples.

    Returns
    -------
    NDArray, shape - (N, -1)
        `X` reshaped so each sample occupies one row.
    """
    n_samples = X.shape[0]
    flat_shape = (n_samples, -1)
    return X.reshape(flat_shape)
|
202
|
+
|
203
|
+
|
204
|
+
def minimum_spanning_tree(X: NDArray) -> Any:
    """
    Returns the minimum spanning tree from a NumPy image array.

    Each sample is flattened to a vector, and the complete graph whose edge
    weights are the pairwise Euclidean distances is reduced to its MST.

    Parameters
    ----------
    X : NDArray
        Numpy image array; the first axis indexes samples.

    Returns
    -------
    Data representing the minimum spanning tree, as returned by
    ``scipy.sparse.csgraph.minimum_spanning_tree``.
    """
    samples = flatten(X)
    pairwise = squareform(pdist(samples))
    # Shift every weight by EPSILON: scipy treats zero entries as missing
    # edges, so duplicate samples would otherwise be disconnected.
    graph = csr_matrix(pairwise + EPSILON)
    return mst(graph)
|
224
|
+
|
225
|
+
|
226
|
+
def get_classes_counts(labels: NDArray) -> Tuple[int, int]:
    """
    Returns the number of distinct classes and the total number of labels.

    Parameters
    ----------
    labels : NDArray
        Numpy labels array

    Returns
    -------
    Tuple[int, int]
        ``(M, N)``: M is the number of unique classes, N the number of labels.

    Raises
    ------
    ValueError
        If the number of unique classes is less than 2
    """
    classes, counts = np.unique(labels, return_counts=True)
    M = len(classes)
    if M < 2:
        raise ValueError("Label vector contains less than 2 classes!")
    # int() so the result is a Python int, as the annotation promises
    N = int(np.sum(counts))
    return M, N
|
250
|
+
|
251
|
+
|
252
|
+
def compute_neighbors(
    A: NDArray,
    B: NDArray,
    k: int = 1,
    algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
) -> NDArray:
    """
    For each sample in A, compute the indices of its k nearest neighbors in B,
    skipping the single closest match.

    ``k + 1`` neighbors are queried and the closest is discarded. This
    presumes each sample of `A` is itself contained in `B` (e.g. ``A is B``),
    so that its closest match is itself. NOTE(review): if `A` and `B` are
    disjoint the true nearest neighbor is discarded, and with duplicate
    samples the dropped match may be a duplicate rather than the sample
    itself — confirm callers pass overlapping sets.

    Parameters
    ----------
    A : NDArray
        Query samples; the first axis indexes samples and all remaining axes
        are flattened into features.
    B : NDArray
        Reference samples, flattened the same way as `A`.
    k : int
        The number of neighbors to find
    algorithm : Literal
        Tree method for nearest neighbor (auto, ball_tree or kd_tree)

    Note
    ----
    Do not use kd_tree if n_features > 20

    Returns
    -------
    NDArray
        Indices into `B` of shape (len(A), k) with length-1 axes squeezed,
        i.e. shape (len(A),) when ``k == 1``.

    Raises
    ------
    ValueError
        If k is less than 1, or if algorithm is not "auto", "ball_tree", or "kd_tree"

    See Also
    --------
    sklearn.neighbors.NearestNeighbors
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    if algorithm not in ["auto", "ball_tree", "kd_tree"]:
        raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")

    A = flatten(A)
    B = flatten(B)

    # one extra neighbor because the first (closest) result is dropped below
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
    nns = nbrs.kneighbors(A)[1]
    nns = nns[:, 1:].squeeze()

    return nns
|
302
|
+
|
303
|
+
|
304
|
+
class BitDepth(NamedTuple):
    """Estimated bit depth of an image with the pixel range that goes with it."""

    # number of bits; get_bitdepth reports 0 for images with negative values
    depth: int
    # lower bound of the pixel range
    pmin: Union[float, int]
    # upper bound of the pixel range
    pmax: Union[float, int]
|
308
|
+
|
309
|
+
|
310
|
+
def get_bitdepth(image: NDArray) -> BitDepth:
    """
    Approximates the bit depth of the image using the
    min and max pixel values.

    Images containing negative values get depth 0 and keep their observed
    min/max as bounds. Otherwise the depth is the first entry of BIT_DEPTH
    with ``2**depth`` above the maximum pixel value (falling back to the
    largest entry) and the bounds are ``[0, 2**depth - 1]``.
    """
    lowest = np.min(image)
    highest = np.max(image)
    if lowest < 0:
        return BitDepth(0, lowest, highest)
    fitting = (bits for bits in BIT_DEPTH if 2**bits > highest)
    depth = next(fitting, max(BIT_DEPTH))
    return BitDepth(depth, 0, 2**depth - 1)
|
321
|
+
|
322
|
+
|
323
|
+
def rescale(image: NDArray, depth: int = 1) -> NDArray:
    """
    Rescales the image using the bit depth provided.

    The current range ``[pmin, pmax]`` is estimated with `get_bitdepth` and
    mapped linearly onto ``[0, 2**depth - 1]``. If the estimated depth already
    equals `depth`, the image is returned unchanged.

    Parameters
    ----------
    image : NDArray
        Image array of any shape.
    depth : int, default 1
        Target bit depth.

    Returns
    -------
    NDArray
        The rescaled image (or `image` itself when no rescaling is needed).
    """
    bitdepth = get_bitdepth(image)
    if bitdepth.depth == depth:
        return image

    value_range = bitdepth.pmax - bitdepth.pmin
    if value_range == 0:
        # Only possible for constant images with negative values (positive
        # depths always give pmax >= 1): avoid a 0/0 division.
        return np.zeros_like(image, dtype=np.float64)

    # Subtract pmin so the minimum maps to 0; pmin is non-zero only for
    # images containing negative values.
    normalized = (image - bitdepth.pmin) / value_range
    return normalized * (2**depth - 1)
|
333
|
+
|
334
|
+
|
335
|
+
def normalize_image_shape(image: NDArray) -> NDArray:
    """
    Normalizes the image shape into (C,H,W).

    2-D (H, W) images gain a leading axis of length 1, 3-D images are
    returned as-is, and higher-rank inputs keep only their last three axes by
    taking index 0 along every leading axis.

    Raises
    ------
    ValueError
        If the image has fewer than 2 dimensions.
    """
    rank = image.ndim
    if rank < 2:
        raise ValueError("Images must have 2 or more dimensions.")
    if rank == 2:
        return np.expand_dims(image, axis=0)
    extra_axes = rank - 3
    if extra_axes == 0:
        return image
    return image[(0,) * extra_axes]
|
349
|
+
|
350
|
+
|
351
|
+
def edge_filter(image: NDArray, offset: float = 0.5) -> NDArray:
    """
    Returns the image filtered using a 3x3 edge detection kernel:
    [[ -1, -1, -1 ],
     [ -1,  8, -1 ],
     [ -1, -1, -1 ]]

    Borders use symmetric padding and the result has the shape of `image`.
    The response is shifted by `offset` and clipped in place to [0, 255].
    """
    response = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm")
    shifted = response + offset
    return np.clip(shifted, 0, 255, out=shifted)
|
361
|
+
|
362
|
+
|
363
|
+
def pchash(image: NDArray) -> str:
    """
    Performs a perceptual hash on an image.

    The image is averaged over channels, rescaled to 8 bits and resampled
    with the Lanczos filter to a square of side N, where N is the largest
    multiple of HASH_SIZE strictly smaller than the shortest image side,
    capped at HASH_SIZE * MAX_FACTOR. A 2-D discrete cosine transform is
    applied, the lowest-frequency HASH_SIZE x HASH_SIZE block is thresholded
    at its median, and the resulting bits are returned as a hex string.

    Parameters
    ----------
    image : NDArray
        An image as a numpy array in CxHxW format

    Returns
    -------
    str
        The hex string hash of the image using perceptual hashing

    Raises
    ------
    ValueError
        If either spatial dimension is not larger than HASH_SIZE.
    """
    shortest_side = min(image.shape[-2:])
    if shortest_side <= HASH_SIZE:
        raise ValueError(f"Image must be larger than {HASH_SIZE}x{HASH_SIZE} for fuzzy hashing.")

    # Side of the square resample: a multiple of HASH_SIZE below the shortest side.
    factor = min((shortest_side - 1) // HASH_SIZE, MAX_FACTOR)
    side = HASH_SIZE * factor

    # Collapse channels, then bring pixel values into the 0-255 range.
    grayscale = normalize_image_shape(image).mean(axis=0).squeeze()
    as_uint8 = rescale(grayscale, 8).astype(np.uint8)

    resampled = Image.fromarray(as_uint8).resize((side, side), Image.Resampling.LANCZOS)
    pixels = np.array(resampled)

    # 2-D DCT; keep only the lowest-frequency block.
    low_freq = dct(dct(pixels.T).T)[:HASH_SIZE, :HASH_SIZE]
    bits = (low_freq > np.median(low_freq)).ravel()

    # Left-pad with False to a whole number of bytes.
    n_padded = int(np.ceil(bits.size / 8) * 8)
    padded = np.zeros(n_padded, dtype=bool)
    padded[-bits.size :] = bits

    digest = np.packbits(padded).tobytes().hex().lstrip("0")
    return digest or "0"
|
412
|
+
|
413
|
+
|
414
|
+
def xxhash(image: NDArray) -> str:
    """
    Hash the raw bytes of an image with the non-cryptographic xxHash
    (XXH3, 64-bit) algorithm.

    Parameters
    ----------
    image : NDArray
        An image as a numpy array; its values are hashed in C (row-major)
        order.

    Returns
    -------
    str
        The hex string hash of the image using the xxHash algorithm
    """
    raw_bytes = image.ravel().tobytes()
    return xxh.xxh3_64_hexdigest(raw_bytes)
|
@@ -8,9 +8,9 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from typing import Literal, Optional, Union, cast
|
10
10
|
|
11
|
-
import numpy as np
|
12
11
|
import tensorflow as tf
|
13
12
|
from keras.layers import Flatten
|
13
|
+
from numpy.typing import NDArray
|
14
14
|
from tensorflow_probability.python.distributions.mvn_diag import MultivariateNormalDiag
|
15
15
|
from tensorflow_probability.python.distributions.mvn_tril import MultivariateNormalTriL
|
16
16
|
from tensorflow_probability.python.stats import covariance
|
@@ -35,12 +35,12 @@ class Elbo:
|
|
35
35
|
def __init__(
|
36
36
|
self,
|
37
37
|
cov_type: Union[Literal["cov_full", "cov_diag"], float] = 1.0,
|
38
|
-
x: Optional[Union[tf.Tensor,
|
38
|
+
x: Optional[Union[tf.Tensor, NDArray]] = None,
|
39
39
|
):
|
40
40
|
if isinstance(cov_type, float):
|
41
41
|
self.cov = ("sim", cov_type)
|
42
42
|
elif cov_type in ["cov_full", "cov_diag"]:
|
43
|
-
x_np:
|
43
|
+
x_np: NDArray = x.numpy() if tf.is_tensor(x) else x # type: ignore
|
44
44
|
cov = covariance(x_np.reshape(x_np.shape[0], -1)) # type: ignore py38
|
45
45
|
if cov_type == "cov_diag": # infer standard deviation from covariance matrix
|
46
46
|
cov = tf.math.sqrt(tf.linalg.diag_part(cov))
|
@@ -11,12 +11,13 @@ from typing import Callable, Iterable, Optional, Tuple, cast
|
|
11
11
|
import keras
|
12
12
|
import numpy as np
|
13
13
|
import tensorflow as tf
|
14
|
+
from numpy.typing import NDArray
|
14
15
|
|
15
16
|
|
16
17
|
def trainer(
|
17
18
|
model: keras.Model,
|
18
|
-
x_train:
|
19
|
-
y_train: Optional[
|
19
|
+
x_train: NDArray,
|
20
|
+
y_train: Optional[NDArray] = None,
|
20
21
|
loss_fn: Optional[Callable[..., tf.Tensor]] = None,
|
21
22
|
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
22
23
|
preprocess_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
|
@@ -21,6 +21,7 @@ from keras.layers import (
|
|
21
21
|
InputLayer,
|
22
22
|
Reshape,
|
23
23
|
)
|
24
|
+
from numpy.typing import NDArray
|
24
25
|
from tensorflow._api.v2.nn import relu, softmax, tanh
|
25
26
|
|
26
27
|
from dataeval._internal.models.tensorflow.autoencoder import AE, AEGMM, VAE, VAEGMM
|
@@ -28,12 +29,12 @@ from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
|
|
28
29
|
|
29
30
|
|
30
31
|
def predict_batch(
|
31
|
-
x: Union[list,
|
32
|
+
x: Union[list, NDArray, tf.Tensor],
|
32
33
|
model: Union[Callable, keras.Model],
|
33
34
|
batch_size: int = int(1e10),
|
34
35
|
preprocess_fn: Optional[Callable] = None,
|
35
36
|
dtype: Union[Type[np.generic], tf.DType] = np.float32,
|
36
|
-
) -> Union[
|
37
|
+
) -> Union[NDArray, tf.Tensor, tuple, list]:
|
37
38
|
"""
|
38
39
|
Make batch predictions on a model.
|
39
40
|
|
@@ -80,7 +81,7 @@ def predict_batch(
|
|
80
81
|
else:
|
81
82
|
raise TypeError(
|
82
83
|
f"Model output type {type(preds_tmp)} not supported. The model output "
|
83
|
-
f"type needs to be one of list, tuple,
|
84
|
+
f"type needs to be one of list, tuple, NDArray or tf.Tensor."
|
84
85
|
)
|
85
86
|
concat = np.concatenate if return_np else tf.concat
|
86
87
|
out = cast(
|
@@ -0,0 +1,82 @@
|
|
1
|
+
import inspect
|
2
|
+
from datetime import datetime, timezone
|
3
|
+
from functools import wraps
|
4
|
+
from typing import Dict, List, Optional
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from dataeval import __version__
|
9
|
+
|
10
|
+
|
11
|
+
class OutputMetadata:
    """
    Base class for output objects that carry execution metadata.

    Public attributes (no leading underscore) hold results. Attributes with a
    leading underscore hold metadata; `set_metadata` assigns them with
    ``object.__setattr__``.
    """

    _name: str
    # set_metadata stores a timezone-aware (UTC) datetime here
    _execution_time: datetime
    _execution_duration: float
    # values are scalars or short descriptive strings produced by set_metadata
    _arguments: Dict[str, str]
    _state: Dict[str, str]
    _version: str

    def dict(self) -> Dict:
        """Return the public (non-underscore) attributes of this output."""
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def meta(self) -> Dict:
        """Return the metadata attributes, keyed without their leading underscore."""
        return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
|
24
|
+
|
25
|
+
|
26
|
+
def set_metadata(module_name: str = "", state_attr: Optional[List[str]] = None):
    """
    Decorator factory that attaches execution metadata to a function's result.

    The decorated callable runs unchanged. Afterwards these attributes are set
    on its return value with ``object.__setattr__``: ``_name``,
    ``_execution_time`` (UTC start time), ``_execution_duration`` (seconds),
    ``_arguments``, ``_state`` and ``_version``.

    Parameters
    ----------
    module_name : str, default ""
        Prefix of ``_name``, which becomes ``f"{module_name}.{name}"``; ``name``
        is the instance's class name when the callable has a ``self``
        parameter, and the function name otherwise.
    state_attr : Optional[List[str]], default None
        For methods, the instance attributes recorded in ``_state``.

    Notes
    -----
    Argument and state values are summarized: scalars are kept as-is, objects
    with a ``shape`` become ``"<class>: shape=<shape>"``, sized objects become
    ``"<class>: len=<n>"``, and anything else (including None) becomes its
    class name.
    """

    def fmt(v):
        # Summarize a value for the metadata record.
        if np.isscalar(v):
            return v
        if hasattr(v, "shape"):
            return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
        if hasattr(v, "__len__"):
            return f"{v.__class__.__name__}: len={len(v)}"
        return f"{v.__class__.__name__}"

    def decorator(fn):
        # The signature is fixed per decorated function: inspect it once here
        # instead of on every call.
        fn_params = inspect.signature(fn).parameters
        # Declared default of every parameter (None when there is none).
        defaults = {k: None if p.default is inspect.Parameter.empty else p.default for k, p in fn_params.items()}

        @wraps(fn)
        def wrapper(*args, **kwargs):
            start = datetime.now(timezone.utc)
            result = fn(*args, **kwargs)
            duration = (datetime.now(timezone.utc) - start).total_seconds()

            # start from defaults, then apply positional args (by parameter
            # order) and explicit keyword args
            arguments = dict(defaults)
            arguments.update(zip(fn_params, args))
            arguments.update(kwargs)
            arguments = {k: fmt(v) for k, v in arguments.items()}

            # a "self" parameter marks a method call; args[0] is the instance
            is_method = "self" in arguments
            state = {k: fmt(getattr(args[0], k)) for k in state_attr} if is_method and state_attr else {}
            name = args[0].__class__.__name__ if is_method else fn.__name__

            metadata = {
                "_name": f"{module_name}.{name}",
                "_execution_time": start,
                "_execution_duration": duration,
                "_arguments": {k: v for k, v in arguments.items() if k != "self"},
                "_state": state,
                "_version": __version__,
            }
            for k, v in metadata.items():
                object.__setattr__(result, k, v)
            return result

        return wrapper

    return decorator
|
69
|
+
|
70
|
+
|
71
|
+
def populate_defaults(d: dict, c: type) -> dict:
    """
    Build a dict with one entry per annotated attribute of class `c`.

    Values present in `d` are used as-is. A missing value gets an empty default
    chosen from its annotation: ``{}`` for dict types, ``[]`` for list types
    and an empty ``np.ndarray`` for ndarray types. Both typing aliases
    (``Dict[...]``, ``List[...]``, ``NDArray[...]``) and builtin generics
    (``dict[...]``, ``list[...]``) are recognized.

    Parameters
    ----------
    d : dict
        Known values keyed by attribute name; keys not annotated on `c` are
        ignored.
    c : type
        Class whose ``__annotations__`` define the output keys.

    Returns
    -------
    dict
        Every annotated attribute name mapped to its value or default.

    Raises
    ------
    TypeError
        If a missing attribute's annotation is not a dict, list or ndarray type.
    """
    from typing import get_origin

    def default(t):
        # get_origin maps Dict[...]/dict[...] -> dict, List[...]/list[...] ->
        # list and NDArray[...] -> np.ndarray; plain classes have no origin.
        origin = get_origin(t) or t
        if origin is dict:
            return {}
        if origin is list:
            return []
        if origin is np.ndarray:
            return np.array([])
        raise TypeError("Unrecognized annotation type")

    return {k: d[k] if k in d else default(t) for k, t in c.__annotations__.items()}
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from typing import Any, Dict, List
|
3
|
+
|
4
|
+
from torch.utils.data import Dataset
|
5
|
+
|
6
|
+
|
7
|
+
def read_dataset(dataset: Dataset) -> List[List[Any]]:
    """
    Extract information from a dataset at each index into individual lists, one per information position

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Input dataset. Items are obtained with ``for data in dataset``, so the
        dataset must be iterable or implement ``__getitem__`` raising
        ``IndexError`` past the end.

    Returns
    -------
    List[List[Any]]
        All objects in individual lists based on return position from dataset.
        Non-tuple items (e.g. images only) are treated as 1-tuples.

    Warning
    -------
    No type checking is done between lists or data inside lists. If items
    return tuples of different lengths, the lists have different lengths.

    See Also
    --------
    torch.utils.data.Dataset

    Examples
    --------
    >>> import numpy as np
    >>> data = np.ones((10, 3, 3))
    >>> labels = np.ones((10,))
    >>> class ICDataset:
    ...     def __init__(self, data, labels):
    ...         self.data = data
    ...         self.labels = labels
    ...
    ...     def __getitem__(self, idx):
    ...         return self.data[idx], self.labels[idx]
    >>> ds = ICDataset(data, labels)
    >>> result = read_dataset(ds)
    >>> len(result)
    2
    >>> np.asarray(result[0]).shape  # 10 3x3 images
    (10, 3, 3)
    >>> np.asarray(result[1]).shape  # 10 labels
    (10,)
    """
    ddict: Dict[int, List] = defaultdict(list)

    for data in dataset:
        # Convert to tuple if single return (e.g. images only)
        if not isinstance(data, tuple):
            data = (data,)

        for i, d in enumerate(data):
            ddict[i].append(d)

    return list(ddict.values())
|