dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,217 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar
6
+
7
+ from dataeval.typing import (
8
+ Array,
9
+ ArrayLike,
10
+ DatasetMetadata,
11
+ ImageClassificationDataset,
12
+ ObjectDetectionDataset,
13
+ )
14
+ from dataeval.utils._array import as_numpy
15
+
16
+
17
def _validate_data(
    datum_type: Literal["ic", "od"],
    images: Array | Sequence[Array],
    labels: Sequence[int] | Sequence[Sequence[int]],
    bboxes: Sequence[Sequence[Sequence[float]]] | None,
    metadata: Sequence[dict[str, Any]] | None,
) -> None:
    """
    Validate the raw inputs used to build a custom annotated dataset.

    Parameters
    ----------
    datum_type : "ic" or "od"
        ``"ic"`` for image classification, ``"od"`` for object detection.
    images : Array | Sequence[Array]
        The images; each must be a 3 dimensional array.
    labels : Sequence[int] | Sequence[Sequence[int]]
        One integer label per image ("ic"), or one sequence of integer labels
        per image ("od").
    bboxes : Sequence[Sequence[Sequence[float]]] | None
        Per-image sequences of 4-coordinate boxes. Required for "od".
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata dictionaries.

    Raises
    ------
    ValueError
        If ``images`` is not a sequence/array, is empty, or its first image is
        not 3 dimensional. Also raised if the number of labels, bboxes or
        metadata entries differs from the number of images.
    TypeError
        If ``labels`` or ``bboxes`` do not have the structure required for
        ``datum_type``.
    """
    from numbers import Integral  # accepts int as well as numpy integer scalars

    image_error = "Images must be a sequence or array of 3 dimensional arrays (H, W, C)."
    od_label_error = "Labels must be a sequence of sequences of integers for object detection."
    box_error = "Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection."

    # Check the container type before calling len() on it or indexing into it.
    if not isinstance(images, (Sequence, Array)):
        raise ValueError(image_error)
    dataset_len = len(images)
    if dataset_len == 0:
        raise ValueError("Images must not be empty.")
    if len(images[0].shape) != 3:
        raise ValueError(image_error)
    if len(labels) != dataset_len:
        raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
    if bboxes is not None and len(bboxes) != dataset_len:
        raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
    if metadata is not None and len(metadata) != dataset_len:
        raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")

    if datum_type == "ic":
        if not isinstance(labels, Sequence) or not isinstance(labels[0], Integral):
            raise TypeError("Labels must be a sequence of integers for image classification.")
    elif datum_type == "od":
        if not isinstance(labels, Sequence) or not isinstance(labels[0], Sequence):
            raise TypeError(od_label_error)
        # Images may contain no objects (empty label/box lists). Inspect the
        # first non-empty entry instead of assuming index 0 is populated.
        first_labels = next((image_labels for image_labels in labels if len(image_labels) > 0), None)
        if first_labels is not None and not isinstance(first_labels[0], Integral):
            raise TypeError(od_label_error)
        if (
            bboxes is None
            or not isinstance(bboxes, (Sequence, Array))
            or not isinstance(bboxes[0], (Sequence, Array))
        ):
            raise TypeError(box_error)
        first_boxes = next((image_boxes for image_boxes in bboxes if len(image_boxes) > 0), None)
        if first_boxes is not None and (
            not isinstance(first_boxes[0], (Sequence, Array)) or len(first_boxes[0]) != 4
        ):
            raise TypeError(box_error)
50
+
51
+
52
def _find_max(arr: ArrayLike) -> Any:
    """
    Return the largest scalar found in a possibly nested, possibly ragged array.

    Every element that is itself a sequence/iterable/array is searched
    recursively. Empty nested entries, such as an object detection image with
    no labels, are skipped instead of raising ``IndexError``.

    Parameters
    ----------
    arr : ArrayLike
        Scalars, or nested sequences/arrays of scalars.

    Returns
    -------
    Any
        The maximum scalar value. If ``arr`` contains no scalars at all, -1 is
        returned, so that ``_find_max(arr) + 1`` yields a class count of 0.
    """

    def _scalars(values: Iterable[Any]) -> Iterable[Any]:
        # Depth-first walk that yields only non-iterable leaf values.
        for value in values:
            if isinstance(value, (Iterable, Sequence, Array)):
                yield from _scalars(value)
            else:
                yield value

    return max(_scalars(arr), default=-1)
57
+
58
+
59
_TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])


class BaseAnnotatedDataset(Generic[_TLabels]):
    """
    Shared storage and bookkeeping for the custom in-memory annotated datasets.

    Parameters
    ----------
    datum_type : "ic" or "od"
        Dataset kind, embedded in the generated dataset id.
    images : Array | Sequence[Array]
        Image data; its length defines the dataset length.
    labels : Sequence[int] | Sequence[Sequence[int]]
        Per-image labels.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata.
    classes : Sequence[str] | None
        Class names indexed by label. When None, names "0" through the largest
        label are generated.
    name : str | None, default None
        Dataset id. When None or empty, an id is derived from the number of
        images, the number of classes and ``datum_type``.
    """

    def __init__(
        self,
        datum_type: Literal["ic", "od"],
        images: Array | Sequence[Array],
        labels: _TLabels,
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        if classes is None:
            classes = [str(label) for label in range(_find_max(labels) + 1)]
        self._classes = classes
        self._index2label = {index: label for index, label in enumerate(classes)}
        self._images = images
        self._labels = labels
        self._metadata = metadata
        if name:
            self._id = name
        else:
            self._id = f"{len(images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata: the dataset id and the index-to-label mapping."""
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self._images)
85
+
86
+
87
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
    """
    In-memory image classification dataset.

    Each item is an ``(image, one_hot_label, metadata)`` tuple. The one-hot
    label has one entry per class, and ``metadata`` is ``{}`` when no metadata
    was provided.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images in the dataset.
    labels : Sequence[int]
        One integer class index per image.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata dictionaries.
    classes : Sequence[str] | None
        Class names indexed by label; generated from the labels when None.
    name : str | None, default None
        Dataset name. It is used as the dataset id and as this instance's class name.
    """

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[int],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        # Forward ``name`` so it also becomes the dataset id reported by ``metadata``.
        super().__init__("ic", images, labels, metadata, classes, name)
        if name is not None:
            self.__name__ = name
            # Rename via a per-instance subclass. Assigning to
            # ``self.__class__.__name__`` would rename the shared class, and
            # therefore every other instance of it.
            base = type(self)
            self.__class__ = type(base)(name, (base,), {"__qualname__": name, "__module__": base.__module__})

    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
        """Return the image, its one-hot encoded label and its metadata at ``idx``."""
        one_hot = [0.0] * len(self._index2label)
        one_hot[self._labels[idx]] = 1.0
        return (
            self._images[idx],
            as_numpy(one_hot),
            self._metadata[idx] if self._metadata is not None else {},
        )
110
+
111
+
112
class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
    """
    In-memory object detection dataset.

    Each item is an ``(image, ObjectDetectionTarget, metadata)`` tuple, where
    ``metadata`` is ``{}`` when no metadata was provided.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images in the dataset.
    labels : Sequence[Sequence[int]]
        Per-image sequences of integer class indices.
    bboxes : Sequence[Sequence[Sequence[float]]]
        Per-image sequences of (x0, y0, x1, y1) boxes, aligned with ``labels``.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata dictionaries.
    classes : Sequence[str] | None
        Class names indexed by label; generated from the labels when None.
    name : str | None, default None
        Dataset name. It is used as the dataset id and as this instance's class name.
    """

    class ObjectDetectionTarget:
        """Labels, boxes and scores for the objects in a single image."""

        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]]) -> None:
            self._labels = labels
            self._bboxes = bboxes
            # One score of 1.0 per label.
            self._scores = [1.0] * len(labels)

        @property
        def labels(self) -> Sequence[int]:
            return self._labels

        @property
        def boxes(self) -> Sequence[Sequence[float]]:
            return self._bboxes

        @property
        def scores(self) -> Sequence[float]:
            return self._scores

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[Sequence[int]],
        bboxes: Sequence[Sequence[Sequence[float]]],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        # Forward ``name`` so it also becomes the dataset id reported by ``metadata``.
        super().__init__("od", images, labels, metadata, classes, name)
        if name is not None:
            self.__name__ = name
            # Rename via a per-instance subclass. Assigning to
            # ``self.__class__.__name__`` would rename the shared class, and
            # therefore every other instance of it.
            base = type(self)
            self.__class__ = type(base)(name, (base,), {"__qualname__": name, "__module__": base.__module__})
        self._bboxes = bboxes

    def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
        """Return the image, its detection target and its metadata at ``idx``."""
        return (
            self._images[idx],
            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx]),
            self._metadata[idx] if self._metadata is not None else {},
        )
157
+
158
+
159
def to_image_classification_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[int],
    metadata: Sequence[dict[str, Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ImageClassificationDataset:
    """
    Helper function to create a custom ImageClassificationDataset instance.

    The inputs are validated before the dataset is constructed.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset; each must be a 3 dimensional array.
    labels : Sequence[int]
        The labels to use in the dataset, one integer class index per image.
    metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset, one dictionary per image.
    classes : Sequence[str] | None
        The classes to use in the dataset. If None, class names "0" through the
        largest label are generated.
    name : str | None, default None
        Optional name for the dataset, passed to the dataset constructor.

    Returns
    -------
    ImageClassificationDataset

    Raises
    ------
    ValueError
        If the images are malformed or the input lengths do not match.
    TypeError
        If the labels are not a sequence of integers.
    """
    _validate_data("ic", images, labels, None, metadata)
    return CustomImageClassificationDataset(images, labels, metadata, classes, name)
186
+
187
+
188
def to_object_detection_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[Sequence[int]],
    bboxes: Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ObjectDetectionDataset:
    """
    Helper function to create a custom ObjectDetectionDataset instance.

    The inputs are validated before the dataset is constructed.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset; each must be a 3 dimensional array.
    labels : Sequence[Sequence[int]]
        The labels to use in the dataset, one sequence of class indices per image.
    bboxes : Sequence[Sequence[Sequence[float]]]
        The bounding boxes (x0, y0, x1, y1) to use in the dataset, one sequence
        of boxes per image.
    metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset, one dictionary per image.
    classes : Sequence[str] | None
        The classes to use in the dataset. If None, class names "0" through the
        largest label are generated.
    name : str | None, default None
        Optional name for the dataset, passed to the dataset constructor.

    Returns
    -------
    ObjectDetectionDataset

    Raises
    ------
    ValueError
        If the images are malformed or the input lengths do not match.
    TypeError
        If the labels or boxes do not have the expected nested structure.
    """
    _validate_data("od", images, labels, bboxes, metadata)
    return CustomObjectDetectionDataset(images, labels, bboxes, metadata, classes, name)
@@ -9,9 +9,8 @@ import torch
9
9
  from torch.utils.data import DataLoader, Subset
10
10
  from tqdm import tqdm
11
11
 
12
- from dataeval.config import get_device
13
- from dataeval.typing import TArray
14
- from dataeval.utils.data._types import Dataset
12
+ from dataeval.config import DeviceLike, get_device
13
+ from dataeval.typing import Array, Dataset
15
14
  from dataeval.utils.torch.models import SupportsEncode
16
15
 
17
16
 
@@ -25,13 +24,14 @@ class Embeddings:
25
24
  ----------
26
25
  dataset : ImageClassificationDataset or ObjectDetectionDataset
27
26
  Dataset to access original images from.
28
- batch_size : int, optional
27
+ batch_size : int
29
28
  Batch size to use when encoding images.
30
- model : torch.nn.Module, optional
29
+ model : torch.nn.Module or None, default None
31
30
  Model to use for encoding images.
32
- device : torch.device, optional
33
- Device to use for encoding images.
34
- verbose : bool, optional
31
+ device : DeviceLike or None, default None
32
+ The hardware device to use if specified, otherwise uses the DataEval
33
+ default or torch default.
34
+ verbose : bool, default False
35
35
  Whether to print progress bar when encoding images.
36
36
  """
37
37
 
@@ -41,11 +41,10 @@ class Embeddings:
41
41
 
42
42
  def __init__(
43
43
  self,
44
- dataset: Dataset[TArray, Any],
44
+ dataset: Dataset[tuple[Array, Any, Any]],
45
45
  batch_size: int,
46
- indices: Sequence[int] | None = None,
47
46
  model: torch.nn.Module | None = None,
48
- device: torch.device | str | None = None,
47
+ device: DeviceLike | None = None,
49
48
  verbose: bool = False,
50
49
  ) -> None:
51
50
  self.device = get_device(device)
@@ -53,7 +52,6 @@ class Embeddings:
53
52
  self.verbose = verbose
54
53
 
55
54
  self._dataset = dataset
56
- self._indices = indices if indices is not None else range(len(dataset))
57
55
  model = torch.nn.Flatten() if model is None else model
58
56
  self._model = model.to(self.device).eval()
59
57
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
@@ -78,7 +76,7 @@ class Embeddings:
78
76
  @torch.no_grad
79
77
  def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
80
78
  # manual batching
81
- dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)
79
+ dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn) # type: ignore
82
80
  for i, images in (
83
81
  tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
84
82
  if self.verbose
@@ -87,7 +85,7 @@ class Embeddings:
87
85
  embeddings = self._encoder(torch.stack(images).to(self.device))
88
86
  yield embeddings
89
87
 
90
- def __getitem__(self, key: int | slice | list[int]) -> torch.Tensor:
88
+ def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
91
89
  if isinstance(key, list):
92
90
  return torch.vstack(list(self._batch(key))).to(self.device)
93
91
  if isinstance(key, slice):
@@ -2,13 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Generic, Iterator, Sequence, overload
5
+ from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
6
6
 
7
- from dataeval.typing import TArray
8
- from dataeval.utils.data._types import Dataset
7
+ from dataeval.typing import Dataset
9
8
 
9
+ T = TypeVar("T")
10
10
 
11
- class Images(Generic[TArray]):
11
+
12
+ class Images(Generic[T]):
12
13
  """
13
14
  Collection of image data from a dataset.
14
15
 
@@ -16,17 +17,15 @@ class Images(Generic[TArray]):
16
17
 
17
18
  Parameters
18
19
  ----------
19
- dataset : ImageClassificationDataset or ObjectDetectionDataset
20
+ dataset : Dataset[tuple[T, ...]] or Dataset[T]
20
21
  Dataset to access images from.
21
22
  """
22
23
 
23
- def __init__(
24
- self,
25
- dataset: Dataset[TArray, Any],
26
- ) -> None:
24
+ def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
25
+ self._is_tuple_datum = isinstance(dataset[0], tuple)
27
26
  self._dataset = dataset
28
27
 
29
- def to_list(self) -> Sequence[TArray]:
28
+ def to_list(self) -> Sequence[T]:
30
29
  """
31
30
  Converts entire dataset to a sequence of images.
32
31
 
@@ -37,29 +36,33 @@ class Images(Generic[TArray]):
37
36
 
38
37
  Returns
39
38
  -------
40
- list[TArray]
39
+ list[T]
41
40
  """
42
41
  return self[:]
43
42
 
44
43
  @overload
45
- def __getitem__(self, key: slice | list[int]) -> Sequence[TArray]: ...
46
-
44
+ def __getitem__(self, key: int, /) -> T: ...
47
45
  @overload
48
- def __getitem__(self, key: int) -> TArray: ...
49
-
50
- def __getitem__(self, key: int | slice | list[int]) -> Sequence[TArray] | TArray:
51
- if isinstance(key, list):
52
- return [self._dataset[i][0] for i in key]
53
- if isinstance(key, slice):
54
- indices = list(range(len(self._dataset))[key])
55
- return [self._dataset[i][0] for i in indices]
56
- elif isinstance(key, int):
57
- return self._dataset[key][0]
58
- raise TypeError("Invalid argument type.")
59
-
60
- def __iter__(self) -> Iterator[TArray]:
46
+ def __getitem__(self, key: slice, /) -> Sequence[T]: ...
47
+
48
+ def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
49
+ if self._is_tuple_datum:
50
+ dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
51
+ if isinstance(key, slice):
52
+ return [dataset[k][0] for k in range(len(self._dataset))[key]]
53
+ elif isinstance(key, int):
54
+ return dataset[key][0]
55
+ else:
56
+ dataset = cast(Dataset[T], self._dataset)
57
+ if isinstance(key, slice):
58
+ return [dataset[k] for k in range(len(self._dataset))[key]]
59
+ elif isinstance(key, int):
60
+ return dataset[key]
61
+ raise TypeError(f"Key must be integers or slices, not {type(key)}")
62
+
63
+ def __iter__(self) -> Iterator[T]:
61
64
  for i in range(len(self._dataset)):
62
- yield self._dataset[i][0]
65
+ yield self[i]
63
66
 
64
67
  def __len__(self) -> int:
65
68
  return len(self._dataset)
@@ -3,18 +3,19 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import warnings
6
- from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
6
+ from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
10
10
 
11
- from dataeval.typing import Array
12
- from dataeval.utils._array import as_numpy, to_numpy
13
- from dataeval.utils._bin import bin_data, digitize_data, is_continuous
14
- from dataeval.utils.data._types import (
15
- Dataset,
11
+ from dataeval.typing import (
12
+ AnnotatedDataset,
13
+ Array,
14
+ ArrayLike,
16
15
  ObjectDetectionTarget,
17
16
  )
17
+ from dataeval.utils._array import as_numpy, to_numpy
18
+ from dataeval.utils._bin import bin_data, digitize_data, is_continuous
18
19
  from dataeval.utils.metadata import merge
19
20
 
20
21
  if TYPE_CHECKING:
@@ -65,7 +66,7 @@ class Metadata:
65
66
 
66
67
  def __init__(
67
68
  self,
68
- dataset: Dataset[Any, Any],
69
+ dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
69
70
  *,
70
71
  continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
71
72
  auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
@@ -276,12 +277,12 @@ class Metadata:
276
277
  if self._processed and not force:
277
278
  return
278
279
 
279
- # Validate the metadata dimensions
280
- self._validate()
281
-
282
280
  # Create image indices from targets
283
281
  self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
284
282
 
283
+ # Validate the metadata dimensions
284
+ self._validate()
285
+
285
286
  # Include specified metadata keys
286
287
  if self.include:
287
288
  metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -341,7 +342,11 @@ class Metadata:
341
342
 
342
343
  # Split out the dictionaries into the keys and values
343
344
  self._discrete_factor_names = list(discrete_metadata.keys())
344
- self._discrete_data = np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
345
+ self._discrete_data = (
346
+ np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
347
+ if discrete_metadata
348
+ else np.array([], dtype=np.int64)
349
+ )
345
350
  self._continuous_factor_names = list(continuous_metadata.keys())
346
351
  self._continuous_data = (
347
352
  np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
@@ -350,3 +355,15 @@ class Metadata:
350
355
  )
351
356
  self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
352
357
  self._processed = True
358
+
359
+ def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
360
+ self._merge()
361
+ self._processed = False
362
+ target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
363
+ if any(len(v) != target_len for v in factors.values()):
364
+ raise ValueError(
365
+ "The lists/arrays in the provided factors have a different length than the current metadata factors."
366
+ )
367
+ merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
368
+ for k, v in factors.items():
369
+ merged[k] = v
@@ -3,12 +3,11 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from enum import IntEnum
6
- from typing import Any, Generic, Iterator, Sequence, TypeVar
6
+ from typing import Generic, Iterator, Sequence, TypeVar
7
7
 
8
- from dataeval.utils.data._types import Dataset
8
+ from dataeval.typing import AnnotatedDataset, DatasetMetadata
9
9
 
10
- _TData = TypeVar("_TData")
11
- _TTarget = TypeVar("_TTarget")
10
+ _TDatum = TypeVar("_TDatum", covariant=True)
12
11
 
13
12
 
14
13
  class SelectionStage(IntEnum):
@@ -17,16 +16,16 @@ class SelectionStage(IntEnum):
17
16
  ORDER = 2
18
17
 
19
18
 
20
- class Selection(Generic[_TData, _TTarget]):
19
+ class Selection(Generic[_TDatum]):
21
20
  stage: SelectionStage
22
21
 
23
- def __call__(self, dataset: Select[_TData, _TTarget]) -> None: ...
22
+ def __call__(self, dataset: Select[_TDatum]) -> None: ...
24
23
 
25
24
  def __str__(self) -> str:
26
25
  return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
27
26
 
28
27
 
29
- class Select(Generic[_TData, _TTarget], Dataset[_TData, _TTarget]):
28
+ class Select(AnnotatedDataset[_TDatum]):
30
29
  """
31
30
  Wraps a dataset and applies selection criteria to it.
32
31
 
@@ -60,35 +59,43 @@ class Select(Generic[_TData, _TTarget], Dataset[_TData, _TTarget]):
60
59
  (data_20, 0, {'id': 20})
61
60
  """
62
61
 
63
- _dataset: Dataset[_TData, _TTarget]
62
+ _dataset: AnnotatedDataset[_TDatum]
64
63
  _selection: list[int]
65
- _selections: Sequence[Selection[_TData, _TTarget]]
64
+ _selections: Sequence[Selection[_TDatum]]
66
65
  _size_limit: int
67
66
 
68
67
  def __init__(
69
68
  self,
70
- dataset: Dataset[_TData, _TTarget],
71
- selections: Selection[_TData, _TTarget] | list[Selection[_TData, _TTarget]] | None = None,
69
+ dataset: AnnotatedDataset[_TDatum],
70
+ selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
72
71
  ) -> None:
72
+ self.__dict__.update(dataset.__dict__)
73
73
  self._dataset = dataset
74
74
  self._size_limit = len(dataset)
75
75
  self._selection = list(range(self._size_limit))
76
76
  self._selections = self._sort_selections(selections)
77
- self.__dict__.update(dataset.__dict__)
77
+
78
+ # Ensure metadata is populated correctly as DatasetMetadata TypedDict
79
+ _metadata = getattr(dataset, "metadata", {})
80
+ if "id" not in _metadata:
81
+ _metadata["id"] = dataset.__class__.__name__
82
+ self._metadata = DatasetMetadata(**_metadata)
78
83
 
79
84
  if self._selections:
80
85
  self._apply_selections()
81
86
 
87
+ @property
88
+ def metadata(self) -> DatasetMetadata:
89
+ return self._metadata
90
+
82
91
  def __str__(self) -> str:
83
92
  nt = "\n "
84
93
  title = f"{self.__class__.__name__} Dataset"
85
94
  sep = "-" * len(title)
86
95
  selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
87
- return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
96
+ return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
88
97
 
89
- def _sort_selections(
90
- self, selections: Selection[_TData, _TTarget] | Sequence[Selection[_TData, _TTarget]] | None
91
- ) -> list[Selection]:
98
+ def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
92
99
  if not selections:
93
100
  return []
94
101
 
@@ -104,14 +111,10 @@ class Select(Generic[_TData, _TTarget], Dataset[_TData, _TTarget]):
104
111
  selection(self)
105
112
  self._selection = self._selection[: self._size_limit]
106
113
 
107
- def __getattr__(self, name: str, /) -> Any:
108
- selfattr = getattr(self._dataset, name, None)
109
- return selfattr if selfattr is not None else getattr(self._dataset, name)
110
-
111
- def __getitem__(self, index: int) -> tuple[_TData, _TTarget, dict[str, Any]]:
114
+ def __getitem__(self, index: int) -> _TDatum:
112
115
  return self._dataset[self._selection[index]]
113
116
 
114
- def __iter__(self) -> Iterator[tuple[_TData, _TTarget, dict[str, Any]]]:
117
+ def __iter__(self) -> Iterator[_TDatum]:
115
118
  for i in range(len(self)):
116
119
  yield self[i]
117
120
 
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import warnings
6
- from dataclasses import dataclass
7
6
  from typing import Any, Iterator, Protocol
8
7
 
9
8
  import numpy as np
@@ -13,32 +12,9 @@ from sklearn.metrics import silhouette_score
13
12
  from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
14
13
  from sklearn.utils.multiclass import type_of_target
15
14
 
16
- from dataeval._output import Output, set_metadata
17
-
18
-
19
- @dataclass
20
- class TrainValSplit:
21
- """Tuple containing train and validation indices"""
22
-
23
- train: NDArray[np.intp]
24
- val: NDArray[np.intp]
25
-
26
-
27
- @dataclass(frozen=True)
28
- class SplitDatasetOutput(Output):
29
- """
30
- Output class containing test indices and a list of TrainValSplits.
31
-
32
- Attributes
33
- ----------
34
- test: NDArray[np.intp]
35
- Indices for the test set
36
- folds: list[TrainValSplit]
37
- List where each index contains the indices for the train and validation splits
38
- """
39
-
40
- test: NDArray[np.intp]
41
- folds: list[TrainValSplit]
15
+ from dataeval.config import get_seed
16
+ from dataeval.outputs._base import set_metadata
17
+ from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
42
18
 
43
19
 
44
20
  class KFoldSplitter(Protocol):
@@ -237,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
237
213
  best_score = 0.50
238
214
  bin_index = np.zeros(len(array), dtype=np.intp)
239
215
  for k in range(2, 20):
240
- clusterer = KMeans(n_clusters=k)
216
+ clusterer = KMeans(n_clusters=k, random_state=get_seed())
241
217
  cluster_labels = clusterer.fit_predict(array)
242
- score = silhouette_score(array, cluster_labels, sample_size=25_000)
218
+ score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
243
219
  if score > best_score:
244
220
  best_score = score
245
221
  bin_index = cluster_labels.astype(np.intp)
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import Iterator
4
+
3
5
  __all__ = []
4
6
 
5
7
  from dataclasses import dataclass
@@ -52,10 +54,16 @@ class Targets:
52
54
  + f" source: {None if self.source is None else self.source.shape}\n"
53
55
  )
54
56
 
57
+ if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
58
+ raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
59
+
55
60
  def __len__(self) -> int:
56
- return len(self.labels)
61
+ if self.source is None:
62
+ return len(self.labels)
63
+ else:
64
+ return len(np.unique(self.source))
57
65
 
58
- def at(self, idx: int) -> Targets:
66
+ def __getitem__(self, idx: int, /) -> Targets:
59
67
  if self.source is None or self.bboxes is None:
60
68
  return Targets(
61
69
  np.atleast_1d(self.labels[idx]),
@@ -71,3 +79,7 @@ class Targets:
71
79
  np.atleast_2d(self.bboxes[mask]),
72
80
  np.atleast_1d(self.source[mask]),
73
81
  )
82
+
83
+ def __iter__(self) -> Iterator[Targets]:
84
+ for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
85
+ yield self[i]
@@ -6,8 +6,10 @@ from abc import abstractmethod
6
6
  from pathlib import Path
7
7
  from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
8
8
 
9
- from dataeval.utils.data._types import (
10
- Dataset,
9
+ from dataeval.utils.data.datasets._fileio import _ensure_exists
10
+ from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
11
+ from dataeval.utils.data.datasets._types import (
12
+ AnnotatedDataset,
11
13
  DatasetMetadata,
12
14
  ImageClassificationDataset,
13
15
  ObjectDetectionDataset,
@@ -16,8 +18,6 @@ from dataeval.utils.data._types import (
16
18
  SegmentationTarget,
17
19
  Transform,
18
20
  )
19
- from dataeval.utils.data.datasets._fileio import _ensure_exists
20
- from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
21
21
 
22
22
  _TArray = TypeVar("_TArray")
23
23
  _TTarget = TypeVar("_TTarget")
@@ -31,7 +31,7 @@ class DataLocation(NamedTuple):
31
31
  checksum: str
32
32
 
33
33
 
34
- class BaseDataset(Dataset[_TArray, _TTarget], Generic[_TArray, _TTarget, _TRawTarget]):
34
+ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget]):
35
35
  """
36
36
  Base class for internet downloaded datasets.
37
37
  """