dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (55)
  1. dataeval/__init__.py +4 -4
  2. dataeval/_internal/detectors/clusterer.py +47 -34
  3. dataeval/_internal/detectors/drift/base.py +53 -35
  4. dataeval/_internal/detectors/drift/cvm.py +5 -4
  5. dataeval/_internal/detectors/drift/ks.py +7 -6
  6. dataeval/_internal/detectors/drift/mmd.py +39 -19
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -8
  9. dataeval/_internal/detectors/duplicates.py +57 -30
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/ae.py +2 -1
  12. dataeval/_internal/detectors/ood/aegmm.py +2 -1
  13. dataeval/_internal/detectors/ood/base.py +37 -15
  14. dataeval/_internal/detectors/ood/llr.py +9 -8
  15. dataeval/_internal/detectors/ood/vae.py +2 -1
  16. dataeval/_internal/detectors/ood/vaegmm.py +2 -1
  17. dataeval/_internal/flags.py +42 -21
  18. dataeval/_internal/interop.py +3 -12
  19. dataeval/_internal/metrics/balance.py +188 -0
  20. dataeval/_internal/metrics/ber.py +123 -48
  21. dataeval/_internal/metrics/coverage.py +90 -74
  22. dataeval/_internal/metrics/divergence.py +101 -67
  23. dataeval/_internal/metrics/diversity.py +211 -0
  24. dataeval/_internal/metrics/parity.py +287 -155
  25. dataeval/_internal/metrics/stats.py +198 -317
  26. dataeval/_internal/metrics/uap.py +40 -29
  27. dataeval/_internal/metrics/utils.py +430 -0
  28. dataeval/_internal/models/tensorflow/losses.py +3 -3
  29. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  30. dataeval/_internal/models/tensorflow/utils.py +4 -3
  31. dataeval/_internal/output.py +82 -0
  32. dataeval/_internal/utils.py +64 -0
  33. dataeval/_internal/workflows/sufficiency.py +96 -107
  34. dataeval/flags/__init__.py +2 -2
  35. dataeval/metrics/__init__.py +26 -7
  36. dataeval/utils/__init__.py +9 -0
  37. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  38. dataeval-0.65.0.dist-info/RECORD +60 -0
  39. dataeval/_internal/functional/__init__.py +0 -0
  40. dataeval/_internal/functional/ber.py +0 -63
  41. dataeval/_internal/functional/coverage.py +0 -75
  42. dataeval/_internal/functional/divergence.py +0 -16
  43. dataeval/_internal/functional/hash.py +0 -79
  44. dataeval/_internal/functional/metadata.py +0 -136
  45. dataeval/_internal/functional/metadataparity.py +0 -190
  46. dataeval/_internal/functional/uap.py +0 -6
  47. dataeval/_internal/functional/utils.py +0 -158
  48. dataeval/_internal/maite/__init__.py +0 -0
  49. dataeval/_internal/maite/utils.py +0 -30
  50. dataeval/_internal/metrics/base.py +0 -92
  51. dataeval/_internal/metrics/metadata.py +0 -610
  52. dataeval/_internal/metrics/metadataparity.py +0 -67
  53. dataeval-0.63.0.dist-info/RECORD +0 -68
  54. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  55. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
dataeval/_internal/metrics/utils.py (new file)
@@ -0,0 +1,430 @@
+ from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import xxhash as xxh
+ from numpy.typing import NDArray
+ from PIL import Image
+ from scipy.fftpack import dct
+ from scipy.signal import convolve2d
+ from scipy.sparse import csr_matrix
+ from scipy.sparse.csgraph import minimum_spanning_tree as mst
+ from scipy.spatial.distance import pdist, squareform
+ from scipy.stats import entropy as sp_entropy
+ from sklearn.neighbors import NearestNeighbors
+
+ EPSILON = 1e-5
+ EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
+ BIT_DEPTH = (1, 8, 12, 16, 32)
+ HASH_SIZE = 8
+ MAX_FACTOR = 4
+
+
+ def get_method(method_map: Dict[str, Callable], method: str) -> Callable:
+     if method not in method_map:
+         raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
+     return method_map[method]
+
+
+ def get_counts(
+     data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
+ ) -> tuple[Dict, Dict]:
+     """
+     Initialize dictionary of histogram counts --- treat categorical values
+     as histogram bins.
+
+     Parameters
+     ----------
+     subset_mask: Optional[NDArray[np.bool_]]
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     counts: Dict
+         histogram counts per metadata factor in `factors`. Each
+         factor will have a different number of bins. Counts get reused
+         across metrics, so hist_counts are cached but only if computed
+         globally, i.e. without masked samples.
+     """
+
+     hist_counts, hist_bins = {}, {}
+     # np.where needed to satisfy linter
+     mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+
+     for cdx, fn in enumerate(names):
+         # linter doesn't like double indexing
+         col_data = data[mask, cdx].squeeze()
+         if is_categorical[cdx]:
+             # if discrete, use unique values as bins
+             bins, cnts = np.unique(col_data, return_counts=True)
+         else:
+             bins = hist_bins.get(fn, "auto")
+             cnts, bins = np.histogram(col_data, bins=bins, density=True)
+
+         hist_counts[fn] = cnts
+         hist_bins[fn] = bins
+
+     return hist_counts, hist_bins
+
+
+ def entropy(
+     data: NDArray,
+     names: List[str],
+     is_categorical: List[bool],
+     normalized: bool = False,
+     subset_mask: Optional[NDArray[np.bool_]] = None,
+ ) -> NDArray[np.float64]:
+     """
+     Meant for use with Bias metrics, Balance, Diversity, ClasswiseBalance,
+     and Classwise Diversity.
+
+     Compute entropy for discrete/categorical variables and for continuous variables through standard
+     histogram binning.
+
+     Parameters
+     ----------
+     normalized: bool
+         Flag that determines whether or not to normalize entropy by log(num_bins)
+     subset_mask: Optional[NDArray[np.bool_]]
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Notes
+     -----
+     For continuous variables, histogram bins are chosen automatically. See
+     numpy.histogram for details.
+
+     Returns
+     -------
+     ent: NDArray[np.float64]
+         Entropy estimate per column of X
+
+     See Also
+     --------
+     numpy.histogram
+     scipy.stats.entropy
+     """
+
+     num_factors = len(names)
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+
+     ev_index = np.empty(num_factors)
+     for col, cnts in enumerate(hist_counts.values()):
+         # entropy in nats, normalizes counts
+         ev_index[col] = sp_entropy(cnts)
+         if normalized:
+             if len(cnts) == 1:
+                 # log(0)
+                 ev_index[col] = 0
+             else:
+                 ev_index[col] /= np.log(len(cnts))
+     return ev_index
+
+
+ def get_num_bins(
+     data: NDArray, names: List[str], is_categorical: List[bool], subset_mask: Optional[NDArray[np.bool_]] = None
+ ) -> NDArray[np.float64]:
+     """
+     Number of bins or unique values for each metadata factor, used to
+     normalize entropy/diversity.
+
+     Parameters
+     ----------
+     subset_mask: Optional[NDArray[np.bool_]]
+         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+     Returns
+     -------
+     NDArray[np.float64]
+     """
+     # likely cached
+     hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+     num_bins = np.empty(len(hist_counts))
+     for idx, cnts in enumerate(hist_counts.values()):
+         num_bins[idx] = len(cnts)
+
+     return num_bins
+
+
+ def infer_categorical(X: NDArray, threshold: float = 0.5) -> NDArray:
+     """
+     Compute fraction of feature values that are unique --- intended to be used
+     for inferring whether variables are categorical.
+     """
+     if X.ndim == 1:
+         X = np.expand_dims(X, axis=1)
+     num_samples = X.shape[0]
+     pct_unique = np.empty(X.shape[1])
+     for col in range(X.shape[1]):  # type: ignore
+         uvals = np.unique(X[:, col], axis=0)
+         pct_unique[col] = len(uvals) / num_samples
+     return pct_unique < threshold
+
+
+ def preprocess_metadata(
+     class_labels: Sequence[int], metadata: List[Dict], cat_thresh: float = 0.2
+ ) -> Tuple[NDArray, List[str], List[bool]]:
+     # convert class_labels and list of metadata dicts to dict of ndarrays
+     metadata_dict: Dict[str, NDArray] = {
+         "class_label": np.asarray(class_labels, dtype=int),
+         **{k: np.array([d[k] for d in metadata]) for k in metadata[0]},
+     }
+
+     # map columns of dict that are not numeric (e.g. string) to numeric values
+     # that mutual information and diversity functions can accommodate. Each
+     # unique string receives a unique integer value.
+     for k, v in metadata_dict.items():
+         # if not numeric
+         if not np.issubdtype(v.dtype, np.number):
+             _, mapped_vals = np.unique(v, return_inverse=True)
+             metadata_dict[k] = mapped_vals
+
+     data = np.stack(list(metadata_dict.values()), axis=-1)
+     names = list(metadata_dict.keys())
+     is_categorical = [infer_categorical(metadata_dict[var], cat_thresh)[0] for var in names]
+
+     return data, names, is_categorical
+
+
+ def flatten(X: NDArray):
+     """
+     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
+
+     Parameters
+     ----------
+     X : NDArray, shape - (N, ... )
+         Input array
+
+     Returns
+     -------
+     NDArray, shape - (N, -1)
+     """
+
+     return X.reshape((X.shape[0], -1))
+
+
+ def minimum_spanning_tree(X: NDArray) -> Any:
+     """
+     Returns the minimum spanning tree from a NumPy image array.
+
+     Parameters
+     ----------
+     X : NDArray
+         Numpy image array
+
+     Returns
+     -------
+     Data representing the minimum spanning tree
+     """
+     # All features belong on second dimension
+     X = flatten(X)
+     # We add a small constant to the distance matrix to ensure scipy interprets
+     # the input graph as fully-connected.
+     dense_eudist = squareform(pdist(X)) + EPSILON
+     eudist_csr = csr_matrix(dense_eudist)
+     return mst(eudist_csr)
+
+
+ def get_classes_counts(labels: NDArray) -> Tuple[int, int]:
+     """
+     Returns the classes and counts of from an array of labels
+
+     Parameters
+     ----------
+     label : NDArray
+         Numpy labels array
+
+     Returns
+     -------
+     Classes and counts
+
+     Raises
+     ------
+     ValueError
+         If the number of unique classes is less than 2
+     """
+     classes, counts = np.unique(labels, return_counts=True)
+     M = len(classes)
+     if M < 2:
+         raise ValueError("Label vector contains less than 2 classes!")
+     N = np.sum(counts).astype(int)
+     return M, N
+
+
+ def compute_neighbors(
+     A: NDArray,
+     B: NDArray,
+     k: int = 1,
+     algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
+ ) -> NDArray:
+     """
+     For each sample in A, compute the nearest neighbor in B
+
+     Parameters
+     ----------
+     A, B : NDArray
+         The n_samples and n_features respectively
+     k : int
+         The number of neighbors to find
+     algorithm : Literal
+         Tree method for nearest neighbor (auto, ball_tree or kd_tree)
+
+     Note
+     ----
+     Do not use kd_tree if n_features > 20
+
+     Returns
+     -------
+     List:
+         Closest points to each point in A and B
+
+     Raises
+     ------
+     ValueError
+         If algorithm is not "auto", "ball_tree", or "kd_tree"
+
+     See Also
+     --------
+     sklearn.neighbors.NearestNeighbors
+     """
+
+     if k < 1:
+         raise ValueError("k must be >= 1")
+     if algorithm not in ["auto", "ball_tree", "kd_tree"]:
+         raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")
+
+     A = flatten(A)
+     B = flatten(B)
+
+     nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
+     nns = nbrs.kneighbors(A)[1]
+     nns = nns[:, 1:].squeeze()
+
+     return nns
+
+
+ class BitDepth(NamedTuple):
+     depth: int
+     pmin: Union[float, int]
+     pmax: Union[float, int]
+
+
+ def get_bitdepth(image: NDArray) -> BitDepth:
+     """
+     Approximates the bit depth of the image using the
+     min and max pixel values.
+     """
+     pmin, pmax = np.min(image), np.max(image)
+     if pmin < 0:
+         return BitDepth(0, pmin, pmax)
+     else:
+         depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
+         return BitDepth(depth, 0, 2**depth - 1)
+
+
+ def rescale(image: NDArray, depth: int = 1) -> NDArray:
+     """
+     Rescales the image using the bit depth provided.
+     """
+     bitdepth = get_bitdepth(image)
+     if bitdepth.depth == depth:
+         return image
+     else:
+         normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
+         return normalized * (2**depth - 1)
+
+
+ def normalize_image_shape(image: NDArray) -> NDArray:
+     """
+     Normalizes the image shape into (C,H,W).
+     """
+     ndim = image.ndim
+     if ndim == 2:
+         return np.expand_dims(image, axis=0)
+     elif ndim == 3:
+         return image
+     elif ndim > 3:
+         # Slice all but the last 3 dimensions
+         return image[(0,) * (ndim - 3)]
+     else:
+         raise ValueError("Images must have 2 or more dimensions.")
+
+
+ def edge_filter(image: NDArray, offset: float = 0.5) -> NDArray:
+     """
+     Returns the image filtered using a 3x3 edge detection kernel:
+     [[ -1, -1, -1 ],
+      [ -1, 8, -1 ],
+      [ -1, -1, -1 ]]
+     """
+     edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
+     np.clip(edges, 0, 255, edges)
+     return edges
+
+
+ def pchash(image: NDArray) -> str:
+     """
+     Performs a perceptual hash on an image by resizing to a square NxN image
+     using the Lanczos algorithm where N is 32x32 or the largest multiple of
+     8 that is smaller than the input image dimensions. The resampled image
+     is compressed using a discrete cosine transform and the lowest frequency
+     component is encoded as a bit array of greater or less than median value
+     and returned as a hex string.
+
+     Parameters
+     ----------
+     image : NDArray
+         An image as a numpy array in CxHxW format
+
+     Returns
+     -------
+     str
+         The hex string hash of the image using perceptual hashing
+     """
+     # Verify that the image is at least larger than an 8x8 image
+     min_dim = min(image.shape[-2:])
+     if min_dim < HASH_SIZE + 1:
+         raise ValueError(f"Image must be larger than {HASH_SIZE}x{HASH_SIZE} for fuzzy hashing.")
+
+     # Calculates the dimensions of the resized square image
+     resize_dim = HASH_SIZE * min((min_dim - 1) // HASH_SIZE, MAX_FACTOR)
+
+     # Normalizes the image to CxHxW and takes the mean over all the channels
+     normalized = np.mean(normalize_image_shape(image), axis=0).squeeze()
+
+     # Rescales the pixel values to an 8-bit 0-255 image
+     rescaled = rescale(normalized, 8).astype(np.uint8)
+
+     # Resizes the image using the Lanczos algorithm to a square image
+     im = np.array(Image.fromarray(rescaled).resize((resize_dim, resize_dim), Image.Resampling.LANCZOS))
+
+     # Performs discrete cosine transforms to compress the image information and takes the lowest frequency component
+     transform = dct(dct(im.T).T)[:HASH_SIZE, :HASH_SIZE]
+
+     # Encodes the transform as a bit array over the median value
+     diff = transform > np.median(transform)
+
+     # Pads the front of the bit array to a multiple of 8 with False
+     padded = np.full(int(np.ceil(diff.size / 8) * 8), False)
+     padded[-diff.size :] = diff.ravel()
+
+     # Converts the bit array to a hex string and strips leading 0s
+     hash_hex = np.packbits(padded).tobytes().hex().lstrip("0")
+     return hash_hex if hash_hex else "0"
+
+
+ def xxhash(image: NDArray) -> str:
+     """
+     Performs a fast non-cryptographic hash using the xxhash algorithm
+     (xxhash.com) against the image as a flattened bytearray. The hash
+     is returned as a hex string.
+
+     Parameters
+     ----------
+     image : NDArray
+         An image as a numpy array
+
+     Returns
+     -------
+     str
+         The hex string hash of the image using the xxHash algorithm
+     """
+     return xxh.xxh3_64_hexdigest(image.ravel().tobytes())
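The new dataeval/_internal/metrics/utils.py appears to consolidate helpers that previously lived in the removed _internal/functional modules (hashing, MST and nearest-neighbor computation, and the entropy/binning utilities used by the new balance and diversity metrics). Below is a quick, hypothetical usage sketch of a few of these helpers; it assumes dataeval 0.65.0 is installed and all values are purely illustrative.

# Hypothetical usage sketch (not part of the diff); values are illustrative.
import numpy as np

from dataeval._internal.metrics.utils import entropy, pchash, preprocess_metadata, xxhash

# Bias-metric helpers: convert per-image metadata into a factor matrix, then
# compute normalized entropy per factor (class label plus each metadata key).
class_labels = [0, 1, 0, 1, 0, 1]
metadata = [{"weather": w} for w in ("rain", "sun", "rain", "snow", "sun", "rain")]
data, names, is_categorical = preprocess_metadata(class_labels, metadata)
print(dict(zip(names, entropy(data, names, is_categorical, normalized=True))))

# Hashing helpers: exact (xxhash) and perceptual (pchash) hashes of a CxHxW image.
image = np.random.default_rng(0).integers(0, 255, (3, 32, 32), dtype=np.uint8)
print(xxhash(image), pchash(image))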
dataeval/_internal/models/tensorflow/losses.py
@@ -8,9 +8,9 @@ Licensed under Apache Software License (Apache 2.0)

  from typing import Literal, Optional, Union, cast

- import numpy as np
  import tensorflow as tf
  from keras.layers import Flatten
+ from numpy.typing import NDArray
  from tensorflow_probability.python.distributions.mvn_diag import MultivariateNormalDiag
  from tensorflow_probability.python.distributions.mvn_tril import MultivariateNormalTriL
  from tensorflow_probability.python.stats import covariance
@@ -35,12 +35,12 @@ class Elbo:
      def __init__(
          self,
          cov_type: Union[Literal["cov_full", "cov_diag"], float] = 1.0,
-         x: Optional[Union[tf.Tensor, np.ndarray]] = None,
+         x: Optional[Union[tf.Tensor, NDArray]] = None,
      ):
          if isinstance(cov_type, float):
              self.cov = ("sim", cov_type)
          elif cov_type in ["cov_full", "cov_diag"]:
-             x_np: np.ndarray = x.numpy() if tf.is_tensor(x) else x  # type: ignore
+             x_np: NDArray = x.numpy() if tf.is_tensor(x) else x  # type: ignore
              cov = covariance(x_np.reshape(x_np.shape[0], -1))  # type: ignore py38
              if cov_type == "cov_diag":  # infer standard deviation from covariance matrix
                  cov = tf.math.sqrt(tf.linalg.diag_part(cov))
dataeval/_internal/models/tensorflow/trainer.py
@@ -11,12 +11,13 @@ from typing import Callable, Iterable, Optional, Tuple, cast
  import keras
  import numpy as np
  import tensorflow as tf
+ from numpy.typing import NDArray


  def trainer(
      model: keras.Model,
-     x_train: np.ndarray,
-     y_train: Optional[np.ndarray] = None,
+     x_train: NDArray,
+     y_train: Optional[NDArray] = None,
      loss_fn: Optional[Callable[..., tf.Tensor]] = None,
      optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
      preprocess_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
dataeval/_internal/models/tensorflow/utils.py
@@ -21,6 +21,7 @@ from keras.layers import (
      InputLayer,
      Reshape,
  )
+ from numpy.typing import NDArray
  from tensorflow._api.v2.nn import relu, softmax, tanh

  from dataeval._internal.models.tensorflow.autoencoder import AE, AEGMM, VAE, VAEGMM
@@ -28,12 +29,12 @@ from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN


  def predict_batch(
-     x: Union[list, np.ndarray, tf.Tensor],
+     x: Union[list, NDArray, tf.Tensor],
      model: Union[Callable, keras.Model],
      batch_size: int = int(1e10),
      preprocess_fn: Optional[Callable] = None,
      dtype: Union[Type[np.generic], tf.DType] = np.float32,
- ) -> Union[np.ndarray, tf.Tensor, tuple, list]:
+ ) -> Union[NDArray, tf.Tensor, tuple, list]:
      """
      Make batch predictions on a model.

@@ -80,7 +81,7 @@ def predict_batch(
          else:
              raise TypeError(
                  f"Model output type {type(preds_tmp)} not supported. The model output "
-                 f"type needs to be one of list, tuple, np.ndarray or tf.Tensor."
+                 f"type needs to be one of list, tuple, NDArray or tf.Tensor."
              )
      concat = np.concatenate if return_np else tf.concat
      out = cast(
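The three tensorflow model hunks above (losses.py, trainer.py, utils.py) share one theme: public-facing array annotations move from the concrete np.ndarray class to the numpy.typing.NDArray alias. A minimal sketch of that annotation style follows, using a toy function rather than anything from dataeval.

# Toy illustration of the numpy.typing.NDArray annotation style; scale() is
# not a dataeval function.
from typing import Optional

import numpy as np
from numpy.typing import NDArray


def scale(x: NDArray, factor: float = 1.0) -> NDArray:
    # Accepts any ndarray; NDArray can also be parameterized, e.g. NDArray[np.float64]
    return np.asarray(x) * factor


y: Optional[NDArray] = scale(np.ones((2, 3)), 0.5)
print(y)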
dataeval/_internal/output.py (new file)
@@ -0,0 +1,82 @@
+ import inspect
+ from datetime import datetime, timezone
+ from functools import wraps
+ from typing import Dict, List, Optional
+
+ import numpy as np
+
+ from dataeval import __version__
+
+
+ class OutputMetadata:
+     _name: str
+     _execution_time: str
+     _execution_duration: float
+     _arguments: Dict[str, str]
+     _state: Dict[str, str]
+     _version: str
+
+     def dict(self) -> Dict:
+         return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
+
+     def meta(self) -> Dict:
+         return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
+
+
+ def set_metadata(module_name: str = "", state_attr: Optional[List[str]] = None):
+     def decorator(fn):
+         @wraps(fn)
+         def wrapper(*args, **kwargs):
+             def fmt(v):
+                 if np.isscalar(v):
+                     return v
+                 if hasattr(v, "shape"):
+                     return f"{v.__class__.__name__}: shape={getattr(v, 'shape')}"
+                 if hasattr(v, "__len__"):
+                     return f"{v.__class__.__name__}: len={len(v)}"
+                 return f"{v.__class__.__name__}"
+
+             time = datetime.now(timezone.utc)
+             result = fn(*args, **kwargs)
+             duration = (datetime.now(timezone.utc) - time).total_seconds()
+             fn_params = inspect.signature(fn).parameters
+             # set all params with defaults then update params with mapped arguments and explicit keyword args
+             arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
+             arguments.update(zip(fn_params, args))
+             arguments.update(kwargs)
+             arguments = {k: fmt(v) for k, v in arguments.items()}
+             state = (
+                 {k: fmt(getattr(args[0], k)) for k in state_attr if "self" in arguments}
+                 if "self" in arguments and state_attr
+                 else {}
+             )
+             name = args[0].__class__.__name__ if "self" in arguments else fn.__name__
+             metadata = {
+                 "_name": f"{module_name}.{name}",
+                 "_execution_time": time,
+                 "_execution_duration": duration,
+                 "_arguments": {k: v for k, v in arguments.items() if k != "self"},
+                 "_state": state,
+                 "_version": __version__,
+             }
+             for k, v in metadata.items():
+                 object.__setattr__(result, k, v)
+             return result
+
+         return wrapper
+
+     return decorator
+
+
+ def populate_defaults(d: dict, c: type) -> dict:
+     def default(t):
+         name = t._name if hasattr(t, "_name") else t.__name__  # py3.9 : _name, py3.10 : __name__
+         if name == "Dict":
+             return {}
+         if name == "List":
+             return []
+         if name == "ndarray":
+             return np.array([])
+         raise TypeError("Unrecognized annotation type")
+
+     return {k: d[k] if k in d else default(t) for k, t in c.__annotations__.items()}
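The new output.py introduces a shared pattern for metric results: result objects carry an OutputMetadata base, and metric functions are wrapped with set_metadata so each result records its name, arguments, timing, and the dataeval version. Below is a hypothetical sketch of how the decorator behaves; MyOutput and my_metric are illustrative names, not dataeval APIs.

# Hypothetical sketch; MyOutput and my_metric are illustrative, not dataeval APIs.
from dataclasses import dataclass

import numpy as np

from dataeval._internal.output import OutputMetadata, set_metadata


@dataclass(frozen=True)
class MyOutput(OutputMetadata):
    score: float


@set_metadata("example.module")
def my_metric(data) -> MyOutput:
    return MyOutput(score=float(np.mean(data)))


result = my_metric(np.ones((4, 4)))
print(result.dict())  # public fields only: {'score': 1.0}
print(result.meta())  # name, execution_time, execution_duration, arguments, state, version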
dataeval/_internal/utils.py (new file)
@@ -0,0 +1,64 @@
+ from collections import defaultdict
+ from typing import Any, Dict, List
+
+ from torch.utils.data import Dataset
+
+
+ def read_dataset(dataset: Dataset) -> List[List[Any]]:
+     """
+     Extract information from a dataset at each index into a individual lists of each information position
+
+     Parameters
+     ----------
+     dataset : torch.utils.data.Dataset
+         Input dataset
+
+     Returns
+     -------
+     List[List[Any]]
+         All objects in individual lists based on return position from dataset
+
+     Warning
+     -------
+     No type checking is done between lists or data inside lists
+
+     See Also
+     --------
+     torch.utils.data.Dataset
+
+     Examples
+     --------
+     >>> import numpy as np
+
+     >>> data = np.ones((10, 3, 3))
+     >>> labels = np.ones((10,))
+     >>> class ICDataset:
+     ...     def __init__(self, data, labels):
+     ...         self.data = data
+     ...         self.labels = labels
+
+     ...     def __getitem__(self, idx):
+     ...         return self.data[idx], self.labels[idx]
+
+     >>> ds = ICDataset(data, labels)
+
+     >>> result = read_dataset(ds)
+     >>> assert len(result) == 2
+     True
+     >>> assert result[0].shape == (10, 3, 3)  # 10 3x3 images
+     True
+     >>> assert result[1].shape == (10,)  # 10 labels
+     True
+     """
+
+     ddict: Dict[int, List] = defaultdict(list)
+
+     for data in dataset:
+         # Convert to tuple if single return (e.g. images only)
+         if not isinstance(data, tuple):
+             data = (data,)
+
+         for i, d in enumerate(data):
+             ddict[i].append(d)
+
+     return list(ddict.values())
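Finally, the new dataeval/_internal/utils.py adds read_dataset, which unzips a map-style torch dataset into one list per return position (e.g. images and labels). A brief usage sketch follows, using torch's TensorDataset purely as a stand-in dataset.

# Brief usage sketch; TensorDataset is just a convenient stand-in dataset.
import torch
from torch.utils.data import TensorDataset

from dataeval._internal.utils import read_dataset

ds = TensorDataset(torch.ones(10, 3, 3), torch.zeros(10, dtype=torch.long))
images, labels = read_dataset(ds)  # one list per return position
print(len(images), len(labels))  # 10 10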