dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +13 -3
  3. dataeval/metadata/__init__.py +2 -2
  4. dataeval/metadata/_ood.py +144 -27
  5. dataeval/metrics/bias/__init__.py +11 -1
  6. dataeval/metrics/bias/_balance.py +3 -3
  7. dataeval/metrics/bias/_completeness.py +130 -0
  8. dataeval/metrics/estimators/_ber.py +2 -1
  9. dataeval/metrics/stats/_base.py +31 -36
  10. dataeval/metrics/stats/_dimensionstats.py +2 -2
  11. dataeval/metrics/stats/_hashstats.py +2 -2
  12. dataeval/metrics/stats/_imagestats.py +4 -4
  13. dataeval/metrics/stats/_labelstats.py +4 -45
  14. dataeval/metrics/stats/_pixelstats.py +2 -2
  15. dataeval/metrics/stats/_visualstats.py +2 -2
  16. dataeval/outputs/__init__.py +4 -2
  17. dataeval/outputs/_bias.py +31 -22
  18. dataeval/outputs/_metadata.py +7 -0
  19. dataeval/outputs/_stats.py +2 -3
  20. dataeval/typing.py +43 -12
  21. dataeval/utils/_array.py +26 -1
  22. dataeval/utils/_mst.py +1 -2
  23. dataeval/utils/data/_dataset.py +2 -0
  24. dataeval/utils/data/_embeddings.py +115 -32
  25. dataeval/utils/data/_images.py +38 -15
  26. dataeval/utils/data/_selection.py +7 -8
  27. dataeval/utils/data/_split.py +76 -129
  28. dataeval/utils/data/datasets/_base.py +4 -2
  29. dataeval/utils/data/datasets/_cifar10.py +17 -9
  30. dataeval/utils/data/datasets/_milco.py +18 -12
  31. dataeval/utils/data/datasets/_mnist.py +24 -8
  32. dataeval/utils/data/datasets/_ships.py +18 -8
  33. dataeval/utils/data/datasets/_types.py +1 -5
  34. dataeval/utils/data/datasets/_voc.py +47 -24
  35. dataeval/utils/data/selections/__init__.py +2 -0
  36. dataeval/utils/data/selections/_classfilter.py +1 -1
  37. dataeval/utils/data/selections/_prioritize.py +296 -0
  38. dataeval/utils/data/selections/_shuffle.py +13 -4
  39. dataeval/utils/metadata.py +1 -1
  40. dataeval/utils/torch/_gmm.py +3 -2
  41. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
  42. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
  43. dataeval/detectors/ood/metadata_ood_mi.py +0 -91
  44. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
  45. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
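
The headline addition in this release is the new Prioritize selection (dataeval/utils/data/selections/_prioritize.py, reproduced in full below). For orientation, here is a minimal, hypothetical usage sketch; the dataset and encoder are placeholders, and the Select(dataset, selections=[...]) call shape is an assumption about the existing selection API rather than something shown in this diff.

    # Hypothetical sketch only -- dataset and encoder are placeholders.
    import torch

    from dataeval.utils.data import Select
    from dataeval.utils.data.selections import Prioritize

    encoder = torch.nn.Flatten()  # stand-in for a real embedding model
    dataset = ...                 # any dataset accepted by Select (placeholder)

    # New in 0.84.0: order the selection by embedding-space difficulty using k-nearest neighbors.
    prioritized = Select(
        dataset,
        selections=[Prioritize(encoder, batch_size=64, device=None, method="knn", k=10)],
    )
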
dataeval/utils/data/datasets/_voc.py
@@ -3,7 +3,7 @@ from __future__ import annotations
  __all__ = []
 
  from pathlib import Path
- from typing import Any, Literal, Sequence, TypeVar
+ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
 
  import torch
  from defusedxml.ElementTree import parse
@@ -16,7 +16,10 @@ from dataeval.utils.data.datasets._base import (
      DataLocation,
  )
  from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
- from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget, Transform
+ from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget
+
+ if TYPE_CHECKING:
+     from dataeval.typing import Transform
 
  _TArray = TypeVar("_TArray")
  _TTarget = TypeVar("_TTarget")
@@ -201,6 +204,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
          boxes: list[list[float]] = []
          label_str = []
          root = parse(annotation).getroot()
+         if root is None:
+             raise ValueError(f"Unable to parse {annotation}")
          num_objects = len(root.findall("object"))
          additional_meta: dict[str, Any] = {
              "folder": [root.findtext("folder", default="") for _ in range(num_objects)],
@@ -253,21 +258,27 @@ class VOCDetection(
          If "base", then the combined dataset of "train" and "val" is returned.
      year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
          The dataset year.
-     transforms : Transform | Sequence[Transform] | None, default None
+     transforms : Transform, Sequence[Transform] or None, default None
          Transform(s) to apply to the data.
      verbose : bool, default False
          If True, outputs print statements.
 
      Attributes
      ----------
-     index2label : dict
+     path : pathlib.Path
+         Location of the folder containing the data.
+     image_set : "train", "val", "test" or "base"
+         The selected image set from the dataset.
+     index2label : dict[int, str]
          Dictionary which translates from class integers to the associated class strings.
-     label2index : dict
+     label2index : dict[str, int]
          Dictionary which translates from class strings to the associated class integers.
-     path : Path
-         Location of the folder containing the data.
-     metadata : dict
-         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     metadata : DatasetMetadata
+         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
+     transforms : Sequence[Transform]
+         The transforms to be applied to the data.
+     size : int
+         The size of the dataset.
      """
 
 
@@ -277,7 +288,7 @@ class VOCDetectionTorch(
      BaseDatasetTorchMixin,
  ):
      """
-     `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+     `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset as PyTorch tensors.
 
      Parameters
      ----------
@@ -291,21 +302,27 @@ class VOCDetectionTorch(
          If "base", then the combined dataset of "train" and "val" is returned.
      year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
          The dataset year.
-     transforms : Transform | Sequence[Transform] | None, default None
+     transforms : Transform, Sequence[Transform] or None, default None
          Transform(s) to apply to the data.
      verbose : bool, default False
          If True, outputs print statements.
 
      Attributes
      ----------
-     index2label : dict
+     path : pathlib.Path
+         Location of the folder containing the data.
+     image_set : "train", "val", "test" or "base"
+         The selected image set from the dataset.
+     index2label : dict[int, str]
          Dictionary which translates from class integers to the associated class strings.
-     label2index : dict
+     label2index : dict[str, int]
          Dictionary which translates from class strings to the associated class integers.
-     path : Path
-         Location of the folder containing the data.
-     metadata : dict
-         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     metadata : DatasetMetadata
+         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
+     transforms : Sequence[Transform]
+         The transforms to be applied to the data.
+     size : int
+         The size of the dataset.
      """
 
 
@@ -329,21 +346,27 @@ class VOCSegmentation(
          If "base", then the combined dataset of "train" and "val" is returned.
      year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
          The dataset year.
-     transforms : Transform | Sequence[Transform] | None, default None
+     transforms : Transform, Sequence[Transform] or None, default None
          Transform(s) to apply to the data.
      verbose : bool, default False
          If True, outputs print statements.
 
      Attributes
      ----------
-     index2label : dict
+     path : pathlib.Path
+         Location of the folder containing the data.
+     image_set : "train", "val", "test" or "base"
+         The selected image set from the dataset.
+     index2label : dict[int, str]
          Dictionary which translates from class integers to the associated class strings.
-     label2index : dict
+     label2index : dict[str, int]
          Dictionary which translates from class strings to the associated class integers.
-     path : Path
-         Location of the folder containing the data.
-     metadata : dict
-         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     metadata : DatasetMetadata
+         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
+     transforms : Sequence[Transform]
+         The transforms to be applied to the data.
+     size : int
+         The size of the dataset.
      """
 
      def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
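
The import change above moves Transform behind a TYPE_CHECKING guard so that dataeval.typing is only imported for static analysis, not at runtime. A generic sketch of the pattern (the module and names below are hypothetical, not dataeval's):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers; avoids runtime import cost and cycles.
        from mypackage.typing import Transform  # hypothetical import

    def apply_all(transforms: list[Transform], value: int) -> int:
        # With postponed annotation evaluation, Transform is never resolved at runtime.
        for transform in transforms:
            value = transform(value)
        return value
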
dataeval/utils/data/selections/__init__.py
@@ -4,6 +4,7 @@ __all__ = [
      "ClassFilter",
      "Indices",
      "Limit",
+     "Prioritize",
      "Reverse",
      "Shuffle",
  ]
@@ -11,5 +12,6 @@ __all__ = [
  from dataeval.utils.data.selections._classfilter import ClassFilter
  from dataeval.utils.data.selections._indices import Indices
  from dataeval.utils.data.selections._limit import Limit
+ from dataeval.utils.data.selections._prioritize import Prioritize
  from dataeval.utils.data.selections._reverse import Reverse
  from dataeval.utils.data.selections._shuffle import Shuffle
dataeval/utils/data/selections/_classfilter.py
@@ -10,7 +10,7 @@ from dataeval.typing import Array, ImageClassificationDatum
  from dataeval.utils._array import as_numpy
  from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
- TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum, covariant=True)
+ TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum)
 
 
  class ClassFilter(Selection[TImageClassificationDatum]):
dataeval/utils/data/selections/_prioritize.py (new file)
@@ -0,0 +1,296 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import logging
+ import warnings
+ from abc import ABC, abstractmethod
+ from typing import Any, Literal, overload
+
+ import numpy as np
+ import torch
+ from numpy.typing import NDArray
+ from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances
+
+ from dataeval.config import EPSILON, DeviceLike, get_seed
+ from dataeval.utils.data import Embeddings, Select
+ from dataeval.utils.data._selection import Selection, SelectionStage
+
+ _logger = logging.getLogger(__name__)
+
+
+ class _Clusters:
+     __slots__ = ["labels", "cluster_centers", "unique_labels"]
+
+     labels: NDArray[np.intp]
+     cluster_centers: NDArray[np.float64]
+     unique_labels: NDArray[np.intp]
+
+     def __init__(self, labels: NDArray[np.intp], cluster_centers: NDArray[np.float64]) -> None:
+         self.labels = labels
+         self.cluster_centers = cluster_centers
+         self.unique_labels = np.unique(labels)
+
+     def _dist2center(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
+         dist = np.zeros(self.labels.shape)
+         for lab in self.unique_labels:
+             dist[self.labels == lab] = np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[lab, :], axis=1)
+         return dist
+
+     def _complexity(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
+         num_clst_intra = int(np.maximum(np.minimum(int(self.unique_labels.shape[0] / 5), 20), 1))
+         d_intra = np.zeros(self.unique_labels.shape)
+         d_inter = np.zeros(self.unique_labels.shape)
+         for cdx, lab in enumerate(self.unique_labels):
+             d_intra[cdx] = np.mean(np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[cdx, :], axis=1))
+             d_inter[cdx] = np.mean(
+                 np.linalg.norm(self.cluster_centers - self.cluster_centers[cdx, :], axis=1)[:num_clst_intra]
+             )
+         cj = d_intra * d_inter
+         tau = 0.1
+         exp = np.exp(cj / tau)
+         prob: NDArray[np.float64] = exp / np.sum(exp)
+         return prob
+
+     def _sort_by_weights(self, X: NDArray[np.float64]) -> NDArray[np.intp]:
+         pr = self._complexity(X)
+         d2c = self._dist2center(X)
+         inds_per_clst: list[NDArray[np.intp]] = []
+         for lab in zip(self.unique_labels):
+             inds = np.nonzero(self.labels == lab)[0]
+             # 'hardest' first
+             srt_inds = np.argsort(d2c[inds])[::-1]
+             inds_per_clst.append(inds[srt_inds])
+         glob_inds: list[NDArray[np.intp]] = []
+         while not bool(np.all([arr.size == 0 for arr in inds_per_clst])):
+             clst_ind = np.random.choice(self.unique_labels, 1, p=pr)[0]
+             if inds_per_clst[clst_ind].size > 0:
+                 glob_inds.append(inds_per_clst[clst_ind][0])
+             else:
+                 continue
+             inds_per_clst[clst_ind] = inds_per_clst[clst_ind][1:]
+         # sorted hardest first; reverse for consistency
+         return np.array(glob_inds[::-1])
+
+
+ class _Sorter(ABC):
+     @abstractmethod
+     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]: ...
+
+
+ class _KNNSorter(_Sorter):
+     def __init__(self, samples: int, k: int | None) -> None:
+         if k is None or k <= 0:
+             k = int(np.sqrt(samples))
+             _logger._log(logging.INFO, f"Setting k to default value of {k}", {"k": k, "samples": samples})
+         elif k >= samples:
+             raise ValueError(f"k={k} should be less than dataset size ({samples})")
+         elif k >= samples / 10 and k > np.sqrt(samples):
+             warnings.warn(
+                 f"Variable k={k} is large with respect to dataset size but valid; "
+                 + f"a nominal recommendation is k={int(np.sqrt(samples))}"
+             )
+         self._k = k
+
+     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+         if reference is None:
+             dists = pairwise_distances(embeddings, embeddings)
+             np.fill_diagonal(dists, np.inf)
+         else:
+             dists = pairwise_distances(embeddings, reference)
+         inds = np.argsort(np.sort(dists, axis=1)[:, self._k])
+         return inds
+
+
+ class _KMeansSorter(_Sorter):
+     def __init__(self, samples: int, c: int | None, n_init: int | Literal["auto", "warn"] = "auto") -> None:
+         if c is None or c <= 0:
+             c = int(np.sqrt(samples))
+             _logger._log(logging.INFO, f"Setting the value of num_clusters to a default value of {c}", {})
+         if c >= samples:
+             raise ValueError(f"c={c} should be less than dataset size ({samples})")
+         self._c = c
+         self._n_init = n_init
+
+     def _get_clusters(self, embeddings: NDArray[Any]) -> _Clusters:
+         clst = KMeans(n_clusters=self._c, init="k-means++", n_init=self._n_init, random_state=get_seed()) # type: ignore - n_init allows int but is typed as str
+         clst.fit(embeddings)
+         if clst.labels_ is None or clst.cluster_centers_ is None:
+             raise ValueError("Clustering failed to produce labels or cluster centers")
+         return _Clusters(clst.labels_, clst.cluster_centers_)
+
+
+ class _KMeansDistanceSorter(_KMeansSorter):
+     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+         clst = self._get_clusters(embeddings if reference is None else reference)
+         inds = np.argsort(clst._dist2center(embeddings))
+         return inds
+
+
+ class _KMeansComplexitySorter(_KMeansSorter):
+     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+         clst = self._get_clusters(embeddings if reference is None else reference)
+         inds = clst._sort_by_weights(embeddings)
+         return inds
+
+
+ class Prioritize(Selection[Any]):
+     """
+     Prioritizes the dataset by sort order in the embedding space.
+
+     Parameters
+     ----------
+     model : torch.nn.Module
+         Model to use for encoding images
+     batch_size : int
+         Batch size to use when encoding images
+     device : DeviceLike or None
+         Device to use for encoding images
+     method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
+         Method to use for prioritization
+     k : int | None, default None
+         Number of nearest neighbors to use for prioritization (knn only)
+     c : int | None, default None
+         Number of clusters to use for prioritization (kmeans only)
+     """
+
+     stage = SelectionStage.ORDER
+
+     @overload
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         batch_size: int,
+         device: DeviceLike | None,
+         method: Literal["knn"],
+         *,
+         k: int | None = None,
+     ) -> None: ...
+
+     @overload
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         batch_size: int,
+         device: DeviceLike | None,
+         method: Literal["kmeans_distance", "kmeans_complexity"],
+         *,
+         c: int | None = None,
+     ) -> None: ...
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         batch_size: int,
+         device: DeviceLike | None,
+         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+         *,
+         k: int | None = None,
+         c: int | None = None,
+     ) -> None:
+         if method not in ("knn", "kmeans_distance", "kmeans_complexity"):
+             raise ValueError(f"Invalid prioritization method: {method}")
+         self._model = model
+         self._batch_size = batch_size
+         self._device = device
+         self._method = method
+         self._embeddings: Embeddings | None = None
+         self._reference: Embeddings | None = None
+         self._k = k
+         self._c = c
+
+     @overload
+     @classmethod
+     def using(
+         cls,
+         method: Literal["knn"],
+         *,
+         k: int | None = None,
+         embeddings: Embeddings | None = None,
+         reference: Embeddings | None = None,
+     ) -> Prioritize: ...
+
+     @overload
+     @classmethod
+     def using(
+         cls,
+         method: Literal["kmeans_distance", "kmeans_complexity"],
+         *,
+         c: int | None = None,
+         embeddings: Embeddings | None = None,
+         reference: Embeddings | None = None,
+     ) -> Prioritize: ...
+
+     @classmethod
+     def using(
+         cls,
+         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+         *,
+         k: int | None = None,
+         c: int | None = None,
+         embeddings: Embeddings | None = None,
+         reference: Embeddings | None = None,
+     ) -> Prioritize:
+         """
+         Prioritizes the dataset by sort order in the embedding space using existing
+         embeddings and/or reference dataset embeddings.
+
+         Parameters
+         ----------
+         method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
+             Method to use for prioritization
+         embeddings : Embeddings or None, default None
+             Embeddings to use for prioritization
+         reference : Embeddings or None, default None
+             Reference embeddings to prioritize relative to
+         k : int or None, default None
+             Number of nearest neighbors to use for prioritization (knn only)
+         c : int or None, default None
+             Number of clusters to use for prioritization (kmeans, cluster only)
+
+         Notes
+         -----
+         At least one of `embeddings` or `reference` must be provided.
+         """
+         emb_params: Embeddings | None = embeddings if embeddings is not None else reference
+         if emb_params is None:
+             raise ValueError("Must provide at least embeddings or reference embeddings.")
+         prioritize = Prioritize(emb_params._model, emb_params.batch_size, emb_params.device, method)
+         prioritize._k = k
+         prioritize._c = c
+         prioritize._embeddings = embeddings
+         prioritize._reference = reference
+         return prioritize
+
+     def _get_sorter(self, samples: int) -> _Sorter:
+         if self._method == "knn":
+             return _KNNSorter(samples, self._k)
+         elif self._method == "kmeans_distance":
+             return _KMeansDistanceSorter(samples, self._c)
+         else: # self._method == "kmeans_complexity"
+             return _KMeansComplexitySorter(samples, self._c)
+
+     def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
+         emb: NDArray[Any] = embeddings.to_tensor(selection).cpu().numpy()
+         emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
+         return emb
+
+     def __call__(self, dataset: Select[Any]) -> None:
+         # Initialize sorter
+         self._sorter = self._get_sorter(len(dataset._selection))
+         # Extract and normalize embeddings
+         embeddings = (
+             Embeddings(dataset, batch_size=self._batch_size, model=self._model, device=self._device)
+             if self._embeddings is None
+             else self._embeddings
+         )
+         if len(dataset._selection) != len(embeddings):
+             raise ValueError(
+                 "Size of embeddings do not match the size of the selection: "
+                 + f"embeddings={len(embeddings)}, selection={len(dataset._selection)}"
+             )
+         emb = self._to_normalized_ndarray(embeddings, dataset._selection)
+         ref = None if self._reference is None else self._to_normalized_ndarray(self._reference)
+         # Sort indices
+         dataset._selection = self._sorter._sort(emb, ref).tolist()
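
Prioritize.using builds the selection from embeddings that were already computed (and/or reference embeddings), reusing the model, batch size, and device stored on the Embeddings object instead of re-encoding inside __call__. A hedged sketch, where train_dataset and encoder are placeholders and the Select call shape is an assumption:

    # Hypothetical sketch only -- train_dataset and encoder are placeholders.
    embeddings = Embeddings(train_dataset, batch_size=64, model=encoder, device=None)
    prioritize = Prioritize.using("kmeans_distance", c=16, embeddings=embeddings)
    selected = Select(train_dataset, selections=[prioritize])
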
dataeval/utils/data/selections/_shuffle.py
@@ -2,10 +2,14 @@ from __future__ import annotations
 
  __all__ = []
 
- from typing import Any
+ from typing import Any, Sequence
 
  import numpy as np
+ from numpy.random import BitGenerator, Generator, SeedSequence
+ from numpy.typing import NDArray
 
+ from dataeval.typing import Array, ArrayLike
+ from dataeval.utils._array import as_numpy
  from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
@@ -15,14 +19,19 @@ class Shuffle(Selection[Any]):
 
      Parameters
      ----------
-     seed
+     seed : int, ArrayLike, SeedSequence, BitGenerator, Generator or None, default None
          Seed for the random number generator.
+
+     See Also
+     --------
+     `NumPy Random Generator <https://numpy.org/doc/stable/reference/random/generator.html>`_
      """
 
+     seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
      stage = SelectionStage.ORDER
 
-     def __init__(self, seed: int):
-         self.seed = seed
+     def __init__(self, seed: int | ArrayLike | SeedSequence | BitGenerator | Generator | None = None):
+         self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed
 
      def __call__(self, dataset: Select[Any]) -> None:
          rng = np.random.default_rng(self.seed)
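
Shuffle now accepts any seed that numpy.random.default_rng understands (int, array-like, SeedSequence, BitGenerator, Generator) or None, converting sequence and Array seeds with as_numpy. A short sketch of the accepted forms, assuming only that dataeval is installed:

    import numpy as np

    from dataeval.utils.data.selections import Shuffle

    Shuffle()                                # unseeded; a fresh default_rng is created per call
    Shuffle(seed=12345)                      # integer seed
    Shuffle(seed=[1, 2, 3])                  # sequence seed, converted via as_numpy
    Shuffle(seed=np.random.SeedSequence(7))  # SeedSequence/BitGenerator/Generator also accepted
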
dataeval/utils/metadata.py
@@ -214,7 +214,7 @@ def flatten(
  output[k] = cv
  else:
  dropped_inner.setdefault(k, set()).add(DropReason.INCONSISTENT_KEY)
- elif not isinstance(cv, list):
+ else:
  output[k] = cv if not size else [cv] * size
 
  if fully_qualified:
dataeval/utils/torch/_gmm.py
@@ -16,6 +16,8 @@ from typing import TypeVar
  import numpy as np
  import torch
 
+ from dataeval.config import EPSILON
+
  TGMMData = TypeVar("TGMMData")
 
 
@@ -74,8 +76,7 @@ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams:
 
      # cholesky decomposition of covariance and determinant derivation
      D = cov.shape[1]
-     eps = 1e-6
-     L = torch.linalg.cholesky(cov + torch.eye(D) * eps) # K x D x D
+     L = torch.linalg.cholesky(cov + torch.eye(D) * EPSILON) # K x D x D
      log_det_cov = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), 1) # K
 
      return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
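
The gmm_params change replaces the local eps = 1e-6 with the shared EPSILON constant from dataeval.config; the jitter term keeps the Cholesky factorization defined when the covariance estimate is singular or nearly so. A standalone illustration with plain torch (1e-6 is the old hard-coded value; the actual dataeval.config.EPSILON may differ):

    import torch

    EPSILON = 1e-6  # stand-in for dataeval.config.EPSILON (assumed value)

    # Rank-deficient covariance: torch.linalg.cholesky(cov) alone would fail.
    cov = torch.tensor([[1.0, 1.0], [1.0, 1.0]], dtype=torch.float64)
    D = cov.shape[0]

    # Adding EPSILON on the diagonal makes the matrix positive definite.
    L = torch.linalg.cholesky(cov + torch.eye(D, dtype=torch.float64) * EPSILON)
    log_det_cov = 2.0 * torch.log(torch.diagonal(L)).sum()
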
{dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.82.1
+ Version: 0.84.0
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -50,9 +50,9 @@ and reference material, please visit our documentation on
 
  <!-- start tagline -->
 
- DataEval curates datasets to train and test performant, robust, unbiased and
- reliable AI models and monitors for data shifts that impact performance of
- deployed models.
+ DataEval analyzes datasets and models to give users the ability to train and
+ test performant, unbiased, and reliable AI models and monitor data for
+ impactful shifts to deployed models.
 
  <!-- end tagline -->
 