dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +13 -3
  3. dataeval/metadata/__init__.py +2 -2
  4. dataeval/metadata/_ood.py +144 -27
  5. dataeval/metrics/bias/__init__.py +11 -1
  6. dataeval/metrics/bias/_balance.py +3 -3
  7. dataeval/metrics/bias/_completeness.py +130 -0
  8. dataeval/metrics/estimators/_ber.py +2 -1
  9. dataeval/metrics/stats/_base.py +31 -36
  10. dataeval/metrics/stats/_dimensionstats.py +2 -2
  11. dataeval/metrics/stats/_hashstats.py +2 -2
  12. dataeval/metrics/stats/_imagestats.py +4 -4
  13. dataeval/metrics/stats/_labelstats.py +4 -45
  14. dataeval/metrics/stats/_pixelstats.py +2 -2
  15. dataeval/metrics/stats/_visualstats.py +2 -2
  16. dataeval/outputs/__init__.py +4 -2
  17. dataeval/outputs/_bias.py +31 -22
  18. dataeval/outputs/_metadata.py +7 -0
  19. dataeval/outputs/_stats.py +2 -3
  20. dataeval/typing.py +43 -12
  21. dataeval/utils/_array.py +26 -1
  22. dataeval/utils/_mst.py +1 -2
  23. dataeval/utils/data/_dataset.py +2 -0
  24. dataeval/utils/data/_embeddings.py +115 -32
  25. dataeval/utils/data/_images.py +38 -15
  26. dataeval/utils/data/_selection.py +7 -8
  27. dataeval/utils/data/_split.py +76 -129
  28. dataeval/utils/data/datasets/_base.py +4 -2
  29. dataeval/utils/data/datasets/_cifar10.py +17 -9
  30. dataeval/utils/data/datasets/_milco.py +18 -12
  31. dataeval/utils/data/datasets/_mnist.py +24 -8
  32. dataeval/utils/data/datasets/_ships.py +18 -8
  33. dataeval/utils/data/datasets/_types.py +1 -5
  34. dataeval/utils/data/datasets/_voc.py +47 -24
  35. dataeval/utils/data/selections/__init__.py +2 -0
  36. dataeval/utils/data/selections/_classfilter.py +1 -1
  37. dataeval/utils/data/selections/_prioritize.py +296 -0
  38. dataeval/utils/data/selections/_shuffle.py +13 -4
  39. dataeval/utils/metadata.py +1 -1
  40. dataeval/utils/torch/_gmm.py +3 -2
  41. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
  42. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
  43. dataeval/detectors/ood/metadata_ood_mi.py +0 -91
  44. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
  45. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/_dataset.py

@@ -47,6 +47,8 @@ def _validate_data(
             or not len(bboxes[0][0]) == 4
         ):
             raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+    else:
+        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
 
 
 def _find_max(arr: ArrayLike) -> Any:
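The new `else` branch makes an unrecognized `datum_type` fail fast with a ValueError instead of slipping past validation. A minimal illustrative mirror of that guard (not the real `_validate_data`, whose full signature is outside this hunk):

```python
# Illustrative only: mirrors the guard added at the end of _validate_data.
def _check_datum_type(datum_type: str) -> None:
    if datum_type not in ("ic", "od"):
        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")

_check_datum_type("od")   # passes silently
_check_datum_type("seg")  # raises ValueError
```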
dataeval/utils/data/_embeddings.py

@@ -3,14 +3,14 @@ from __future__ import annotations
 __all__ = []
 
 import math
-from typing import Any, Iterator, Sequence
+from typing import Any, Iterator, Sequence, cast
 
 import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
 from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, Dataset, Transform
 from dataeval.utils.torch.models import SupportsEncode
 
 
@@ -26,11 +26,15 @@ class Embeddings:
         Dataset to access original images from.
     batch_size : int
         Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
     model : torch.nn.Module or None, default None
         Model to use for encoding images.
     device : DeviceLike or None, default None
         The hardware device to use if specified, otherwise uses the DataEval
         default or torch default.
+    cache : bool, default False
+        Whether to cache the embeddings in memory.
     verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
@@ -43,61 +47,140 @@ class Embeddings:
         self,
         dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
         model: torch.nn.Module | None = None,
         device: DeviceLike | None = None,
+        cache: bool = False,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
-        self.batch_size = batch_size
+        self.cache = cache
+        self.batch_size = batch_size if batch_size > 0 else 1
         self.verbose = verbose
 
         self._dataset = dataset
+        self._length = len(dataset)
         model = torch.nn.Flatten() if model is None else model
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
         self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
+        self._cached_idx = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+        self._shallow: bool = False
 
-    def to_tensor(self) -> torch.Tensor:
+    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
-        Converts entire dataset to embeddings.
+        Converts dataset to embeddings.
 
-        Warning
-        -------
-        Will process the entire dataset in batches and return
-        embeddings as a single Tensor in memory.
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
 
         Returns
         -------
         torch.Tensor
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        if indices is not None:
+            return torch.vstack(list(self._batch(indices))).to(self.device)
+        else:
+            return self[:]
+
+    @classmethod
+    def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
         """
-        return self[:]
+        Instantiates a shallow Embeddings object using an array.
 
-    # Reduce overhead cost by not tracking tensor gradients
-    @torch.no_grad
+        Parameters
+        ----------
+        array : Array
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.utils.data._embeddings import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        embeddings._length = len(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._shallow = True
+        return embeddings
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
     def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
-        # manual batching
-        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
-        for i, images in (
-            tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
-            if self.verbose
-            else enumerate(dataloader)
-        ):
-            embeddings = self._encoder(torch.stack(images).to(self.device))
-            yield embeddings
-
-    def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
-        if isinstance(key, list):
-            return torch.vstack(list(self._batch(key))).to(self.device)
-        if isinstance(key, slice):
-            return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
-        elif isinstance(key, int):
-            return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
-        raise TypeError("Invalid argument type.")
+        dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as a single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self._dataset), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+            yield self._embeddings[batch]
+
+    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        if self._shallow:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            return self._embeddings[key]
+
+        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result
 
     def __iter__(self) -> Iterator[torch.Tensor]:
         # process in batches while yielding individual embeddings
-        for batch in self._batch(range(len(self._dataset))):
+        for batch in self._batch(range(self._length)):
             yield from batch
 
     def __len__(self) -> int:
-        return len(self._dataset)
+        return self._length
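Taken together, the Embeddings changes above add pre-encoding transforms, an optional in-memory cache, an index-aware to_tensor, and a shallow from_array constructor. A minimal usage sketch based only on the signatures shown in this hunk; the list-backed dataset and encoder below are hypothetical placeholders, not DataEval fixtures (the import path is the private module the diff's own doctest uses):

```python
import torch

from dataeval.utils.data._embeddings import Embeddings

# Hypothetical dataset of (image, target, metadata) tuples with CHW images.
dataset = [(torch.rand(3, 32, 32), 0, {}) for _ in range(16)]
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 8))

embeddings = Embeddings(
    dataset,
    batch_size=4,
    transforms=[lambda t: t * 2 - 1],  # applied to each image before encoding
    model=encoder,
    cache=True,                        # encoded batches are kept in memory
)
first_four = embeddings.to_tensor(indices=[0, 1, 2, 3])  # encodes only these indices
print(first_four.shape)  # torch.Size([4, 8])

# Shallow embeddings wrap an existing array without any encoding step.
shallow = Embeddings.from_array(torch.rand(10, 8))
print(shallow[:].shape)  # torch.Size([10, 8])
```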
dataeval/utils/data/_images.py

@@ -2,11 +2,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import Dataset
+from dataeval.typing import Array, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last
 
-T = TypeVar("T")
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
+T = TypeVar("T", bound=Array)
 
 
 class Images(Generic[T]):
@@ -21,7 +25,10 @@ class Images(Generic[T]):
         Dataset to access images from.
     """
 
-    def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
+    def __init__(
+        self,
+        dataset: Dataset[tuple[T, Any, Any] | T],
+    ) -> None:
         self._is_tuple_datum = isinstance(dataset[0], tuple)
         self._dataset = dataset
 
@@ -40,25 +47,41 @@ class Images(Generic[T]):
         """
         return self[:]
 
+    def plot(
+        self,
+        indices: Sequence[int],
+        images_per_row: int = 3,
+        figsize: tuple[int, int] = (10, 10),
+    ) -> Figure:
+        import matplotlib.pyplot as plt
+
+        num_images = len(indices)
+        num_rows = (num_images + images_per_row - 1) // images_per_row
+        fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
+        for i, ax in enumerate(axes.flatten()):
+            image = channels_first_to_last(as_numpy(self[i]))
+            ax.imshow(image)
+            ax.axis("off")
+        plt.tight_layout()
+        return fig
+
     @overload
     def __getitem__(self, key: int, /) -> T: ...
     @overload
     def __getitem__(self, key: slice, /) -> Sequence[T]: ...
 
     def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+        if isinstance(key, slice):
+            return [self._get_image(k) for k in range(len(self._dataset))[key]]
+        elif hasattr(key, "__int__"):
+            return self._get_image(int(key))
+        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+    def _get_image(self, index: int) -> T:
         if self._is_tuple_datum:
-            dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
-            if isinstance(key, slice):
-                return [dataset[k][0] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key][0]
+            return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
         else:
-            dataset = cast(Dataset[T], self._dataset)
-            if isinstance(key, slice):
-                return [dataset[k] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key]
-        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+            return cast(Dataset[T], self._dataset)[index]
 
     def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
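The new plot helper lazily imports matplotlib and renders images in a grid after converting them to channels-last numpy arrays. A small usage sketch; the list-backed dataset below is a hypothetical stand-in for a real image dataset, and the import uses the private module path from the changed-files list (Images may also be re-exported elsewhere in the package):

```python
import numpy as np

from dataeval.utils.data._images import Images

# Hypothetical dataset of (image, target, metadata) tuples with CHW float images.
dataset = [(np.random.rand(3, 16, 16), 0, {}) for _ in range(9)]

images = Images(dataset)
fig = images.plot(indices=[0, 1, 2, 3, 4, 5], images_per_row=3)
fig.savefig("image_grid.png")  # requires matplotlib to be installed
```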
dataeval/utils/data/_selection.py

@@ -7,7 +7,7 @@ from typing import Generic, Iterator, Sequence, TypeVar
 
 from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
-_TDatum = TypeVar("_TDatum", covariant=True)
+_TDatum = TypeVar("_TDatum")
 
 
 class SelectionStage(IntEnum):
@@ -67,13 +67,13 @@ class Select(AnnotatedDataset[_TDatum]):
     def __init__(
         self,
         dataset: AnnotatedDataset[_TDatum],
-        selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
+        selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
     ) -> None:
         self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self._sort_selections(selections)
+        self._selections = self._sort(selections)
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,8 +81,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)
 
-        if self._selections:
-            self._apply_selections()
+        self._select()
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -92,10 +91,10 @@ class Select(AnnotatedDataset[_TDatum]):
         nt = "\n "
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
-        selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
+        selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
         return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
+    def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
             return []
 
@@ -106,7 +105,7 @@ class Select(AnnotatedDataset[_TDatum]):
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list
 
-    def _apply_selections(self) -> None:
+    def _select(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]
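With selections now typed as a Sequence rather than a list, a tuple of selections works as well, and the renamed _select() is applied unconditionally at construction. A hedged sketch; it assumes Select is exported from dataeval.utils.data and that Shuffle and Limit selections exist in dataeval.utils.data.selections with the arguments shown (only the shuffle module appears in the changed-files list, so verify names against the installed version), and the tiny dataset below is a hypothetical stand-in for an AnnotatedDataset:

```python
from dataeval.utils.data import Select
from dataeval.utils.data.selections import Limit, Shuffle  # assumed exports

class TinyDataset:
    """Hypothetical stand-in for an AnnotatedDataset of (image, target, metadata) tuples."""
    metadata = {"id": "TinyDataset"}

    def __init__(self, n: int) -> None:
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, index: int):
        return ([index], index % 2, {"index": index})

# Selections are sorted by stage (_sort) and applied once in __init__ (_select).
subset = Select(TinyDataset(500), selections=(Shuffle(seed=0), Limit(size=100)))
print(len(subset))  # at most 100 after the limit is applied
print(subset)       # __str__ lists the applied selections and the selected size
```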