dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. The information is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (113)
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/_metadata.py (new file)
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.typing import (
+    AnnotatedDataset,
+    Array,
+    ObjectDetectionTarget,
+)
+from dataeval.utils._array import as_numpy, to_numpy
+from dataeval.utils._bin import bin_data, digitize_data, is_continuous
+from dataeval.utils.metadata import merge
+
+if TYPE_CHECKING:
+    from dataeval.utils.data import Targets
+else:
+    from dataeval.utils.data._targets import Targets
+
+
+class Metadata:
+    """
+    Class containing binned metadata.
+
+    Attributes
+    ----------
+    discrete_factor_names : list[str]
+        List containing factor names for the original data that was discrete and
+        the binned continuous data
+    discrete_data : NDArray[np.int64]
+        Array containing values for the original data that was discrete and the
+        binned continuous data
+    continuous_factor_names : list[str]
+        List containing factor names for the original continuous data
+    continuous_data : NDArray[np.float64] | None
+        Array containing values for the original continuous data or None if there
+        was no continuous data
+    class_labels : NDArray[np.int]
+        Numerical class labels for the images/objects
+    class_names : list[str]
+        List of unique class names
+    total_num_factors : int
+        Sum of discrete_factor_names and continuous_factor_names plus 1 for class
+    image_indices : NDArray[np.intp]
+        Array of the image index that is mapped by the index of the factor
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset or ObjectDetectionDataset
+        Dataset to access original targets and metadata from.
+    continuous_factor_bins : Mapping[str, int | Sequence[float]] | None, default None
+        Mapping from continuous factor name to the number of bins or bin edges
+    auto_bin_method : Literal["uniform_width", "uniform_count", "clusters"], default "uniform_width"
+        Method for automatically determining the number of bins for continuous factors
+    exclude : Sequence[str] | None, default None
+        Filter metadata factors to exclude the specified factors, cannot be set with `include`
+    include : Sequence[str] | None, default None
+        Filter metadata factors to include the specified factors, cannot be set with `exclude`
+    """
+
+    def __init__(
+        self,
+        dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
+        *,
+        continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
+        auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
+        exclude: Sequence[str] | None = None,
+        include: Sequence[str] | None = None,
+    ) -> None:
+        self._collated = False
+        self._merged = None
+        self._processed = False
+
+        self._dataset = dataset
+        self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
+        self._auto_bin_method = auto_bin_method
+
+        if exclude is not None and include is not None:
+            raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
+
+        self._exclude = set(exclude or ())
+        self._include = set(include or ())
+
+    @property
+    def targets(self) -> Targets:
+        self._collate()
+        return self._targets
+
+    @property
+    def raw(self) -> list[dict[str, Any]]:
+        self._collate()
+        return self._raw
+
+    @property
+    def exclude(self) -> set[str]:
+        return self._exclude
+
+    @exclude.setter
+    def exclude(self, value: Sequence[str]) -> None:
+        exclude = set(value)
+        if self._exclude != exclude:
+            self._exclude = exclude
+            self._include = set()
+            self._processed = False
+
+    @property
+    def include(self) -> set[str]:
+        return self._include
+
+    @include.setter
+    def include(self, value: Sequence[str]) -> None:
+        include = set(value)
+        if self._include != include:
+            self._include = include
+            self._exclude = set()
+            self._processed = False
+
+    @property
+    def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+        return self._continuous_factor_bins
+
+    @continuous_factor_bins.setter
+    def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
+        if self._continuous_factor_bins != bins:
+            self._continuous_factor_bins = dict(bins)
+            self._processed = False
+
+    @property
+    def auto_bin_method(self) -> str:
+        return self._auto_bin_method
+
+    @auto_bin_method.setter
+    def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
+        if self._auto_bin_method != method:
+            self._auto_bin_method = method
+            self._processed = False
+
+    @property
+    def merged(self) -> dict[str, Any]:
+        self._merge()
+        return {} if self._merged is None else self._merged[0]
+
+    @property
+    def dropped_factors(self) -> dict[str, list[str]]:
+        self._merge()
+        return {} if self._merged is None else self._merged[1]
+
+    @property
+    def discrete_factor_names(self) -> list[str]:
+        self._process()
+        return self._discrete_factor_names
+
+    @property
+    def discrete_data(self) -> NDArray[np.int64]:
+        self._process()
+        return self._discrete_data
+
+    @property
+    def continuous_factor_names(self) -> list[str]:
+        self._process()
+        return self._continuous_factor_names
+
+    @property
+    def continuous_data(self) -> NDArray[np.float64]:
+        self._process()
+        return self._continuous_data
+
+    @property
+    def class_labels(self) -> NDArray[np.intp]:
+        self._collate()
+        return self._class_labels
+
+    @property
+    def class_names(self) -> list[str]:
+        self._collate()
+        return self._class_names
+
+    @property
+    def total_num_factors(self) -> int:
+        self._process()
+        return self._total_num_factors
+
+    @property
+    def image_indices(self) -> NDArray[np.intp]:
+        self._process()
+        return self._image_indices
+
+    def _collate(self, force: bool = False):
+        if self._collated and not force:
+            return
+
+        raw: list[dict[str, Any]] = []
+
+        labels = []
+        bboxes = []
+        scores = []
+        srcidx = []
+        is_od = None
+        for i in range(len(self._dataset)):
+            _, target, metadata = self._dataset[i]
+
+            raw.append(metadata)
+
+            if is_od_target := isinstance(target, ObjectDetectionTarget):
+                target_len = len(target.labels)
+                labels.extend(as_numpy(target.labels).tolist())
+                bboxes.extend(as_numpy(target.boxes).tolist())
+                scores.extend(as_numpy(target.scores).tolist())
+                srcidx.extend([i] * target_len)
+            elif isinstance(target, Array):
+                target_len = 1
+                labels.append(int(np.argmax(as_numpy(target))))
+                scores.append(target)
+            else:
+                raise TypeError("Encountered unsupported target type in dataset")
+
+            is_od = is_od_target if is_od is None else is_od
+            if is_od != is_od_target:
+                raise ValueError("Encountered unexpected target type in dataset")
+
+        labels = as_numpy(labels).astype(np.intp)
+        scores = as_numpy(scores).astype(np.float32)
+        bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
+        srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None
+
+        self._targets = Targets(labels, scores, bboxes, srcidx)
+        self._raw = raw
+
+        index2label = self._dataset.metadata.get("index2label", {})
+        self._class_labels = self._targets.labels
+        self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
+        self._collated = True
+
+    def _merge(self, force: bool = False):
+        if self._merged is not None and not force:
+            return
+
+        targets_per_image = (
+            None if self.targets.source is None else np.unique(self.targets.source, return_counts=True)[1].tolist()
+        )
+        self._merged = merge(self.raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+
+    def _validate(self) -> None:
+        # Check that metadata is a single, flattened dictionary with uniform array lengths
+        check_length = None
+        if self._targets.labels.ndim > 1:
+            raise ValueError(
+                f"Got class labels with {self._targets.labels.ndim}-dimensional "
+                f"shape {self._targets.labels.shape}, but expected a 1-dimensional array."
+            )
+        for v in self.merged.values():
+            if not isinstance(v, (list, tuple, np.ndarray)):
+                raise TypeError(
+                    "Metadata dictionary needs to be a single dictionary whose values "
+                    "are arraylike containing the metadata on a per image or per object basis."
+                )
+            else:
+                check_length = len(v) if check_length is None else check_length
+                if check_length != len(v):
+                    raise ValueError(
+                        "The lists/arrays in the metadata dict have varying lengths. "
+                        "Metadata requires them to be uniform in length."
+                    )
+        if len(self._class_labels) != check_length:
+            raise ValueError(
+                f"The length of the label array {len(self._class_labels)} is not the same as "
+                f"the length of the metadata arrays {check_length}."
+            )
+
+    def _process(self, force: bool = False) -> None:
+        if self._processed and not force:
+            return
+
+        # Trigger collate and merge if not yet done
+        self._collate()
+        self._merge()
+
+        # Validate the metadata dimensions
+        self._validate()
+
+        # Create image indices from targets
+        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
+
+        # Include specified metadata keys
+        if self.include:
+            metadata = {i: self.merged[i] for i in self.include if i in self.merged}
+            continuous_factor_bins = (
+                {i: self.continuous_factor_bins[i] for i in self.include if i in self.continuous_factor_bins}
+                if self.continuous_factor_bins
+                else {}
+            )
+        else:
+            metadata = self.merged
+            continuous_factor_bins = dict(self.continuous_factor_bins) if self.continuous_factor_bins else {}
+            for k in self.exclude:
+                metadata.pop(k, None)
+                continuous_factor_bins.pop(k, None)
+
+        # Remove generated "_image_index" if present
+        if "_image_index" in metadata:
+            metadata.pop("_image_index", None)
+
+        # Bin according to user supplied bins
+        continuous_metadata = {}
+        discrete_metadata = {}
+        if continuous_factor_bins:
+            invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
+            if invalid_keys:
+                raise KeyError(
+                    f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                    "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
+                    "or add corresponding entries to the `metadata` dictionary."
+                )
+            for factor, bins in continuous_factor_bins.items():
+                discrete_metadata[factor] = digitize_data(metadata[factor], bins)
+                continuous_metadata[factor] = metadata[factor]
+
+        # Determine category of the rest of the keys
+        remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
+        for key in remaining_keys:
+            data = to_numpy(metadata[key])
+            if np.issubdtype(data.dtype, np.number):
+                result = is_continuous(data, self._image_indices)
+                if result:
+                    continuous_metadata[key] = data
+                unique_samples, ordinal_data = np.unique(data, return_inverse=True)
+                if unique_samples.size <= np.max([20, data.size * 0.01]):
+                    discrete_metadata[key] = ordinal_data
+                else:
+                    warnings.warn(
+                        f"A user defined binning was not provided for {key}. "
+                        f"Using the {self.auto_bin_method} method to discretize the data. "
+                        "It is recommended that the user rerun and supply the desired "
+                        "bins using the continuous_factor_bins parameter.",
+                        UserWarning,
+                    )
+                    discrete_metadata[key] = bin_data(data, self.auto_bin_method)
+            else:
+                _, discrete_metadata[key] = np.unique(data, return_inverse=True)
+
+        # Split out the dictionaries into the keys and values
+        self._discrete_factor_names = list(discrete_metadata.keys())
+        self._discrete_data = (
+            np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
+            if discrete_metadata
+            else np.array([], dtype=np.int64)
+        )
+        self._continuous_factor_names = list(continuous_metadata.keys())
+        self._continuous_data = (
+            np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
+            if continuous_metadata
+            else np.array([], dtype=np.float64)
+        )
+        self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
+        self._processed = True
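
For orientation, a minimal usage sketch of the new Metadata wrapper follows. The TinyDataset class, its factor names, and the bin counts are hypothetical stand-ins, not part of the release; the sketch assumes merge flattens per-item metadata dicts into per-factor arrays, so exact factor names and ordering in the output may differ. Metadata is imported here from the private module added above; it may also be re-exported from dataeval.utils.data, which this hunk does not show.

# Hypothetical sketch (not part of the release): a tiny in-memory dataset that
# satisfies the (datum, target, metadata-dict) contract expected by Metadata.
import numpy as np

from dataeval.utils.data._metadata import Metadata


class TinyDataset:
    metadata = {"id": "TinyDataset", "index2label": {0: "cat", 1: "dog"}}

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int):
        image = np.zeros((3, 8, 8))                  # placeholder image
        target = np.eye(2)[idx % 2]                  # one-hot class scores
        factors = {"altitude": 100.0 * idx, "weather": "rain" if idx % 2 else "clear"}
        return image, target, factors


# Collation, merging, binning and validation all run lazily on first property access.
meta = Metadata(TinyDataset(), continuous_factor_bins={"altitude": 2})
print(meta.discrete_factor_names)   # discrete factors plus the binned "altitude"
print(meta.discrete_data.shape)     # (4, num_discrete_factors)
print(meta.class_names)             # class names resolved through index2label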
dataeval/utils/data/_selection.py (new file)
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+__all__ = []
+
+from enum import IntEnum
+from typing import Any, Generic, Iterator, Sequence, TypeVar
+
+from dataeval.typing import AnnotatedDataset, DatasetMetadata
+
+_TDatum = TypeVar("_TDatum")
+
+
+class SelectionStage(IntEnum):
+    STATE = 0
+    FILTER = 1
+    ORDER = 2
+
+
+class Selection(Generic[_TDatum]):
+    stage: SelectionStage
+
+    def __call__(self, dataset: Select[_TDatum]) -> None: ...
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
+
+
+class Select(AnnotatedDataset[_TDatum]):
+    """
+    Wraps a dataset and applies selection criteria to it.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        The dataset to wrap.
+    selections : Selection or list[Selection], optional
+        The selection criteria to apply to the dataset.
+
+    Examples
+    --------
+    >>> from dataeval.utils.data.selections import ClassFilter, Limit
+
+    >>> # Construct a sample dataset with size of 100 and class count of 10
+    >>> # Elements at index `idx` are returned as tuples:
+    >>> # - f"data_{idx}", one_hot_encoded(idx % class_count), {"id": idx}
+    >>> dataset = SampleDataset(size=100, class_count=10)
+
+    >>> # Apply a selection criteria to the dataset
+    >>> selections = [Limit(size=5), ClassFilter(classes=[0, 2])]
+    >>> selected_dataset = Select(dataset, selections=selections)
+
+    >>> # Iterate over the selected dataset
+    >>> for data, target, meta in selected_dataset:
+    ...     print(f"({data}, {np.argmax(target)}, {meta})")
+    (data_0, 0, {'id': 0})
+    (data_2, 2, {'id': 2})
+    (data_10, 0, {'id': 10})
+    (data_12, 2, {'id': 12})
+    (data_20, 0, {'id': 20})
+    """
+
+    _dataset: AnnotatedDataset[_TDatum]
+    _selection: list[int]
+    _selections: Sequence[Selection[_TDatum]]
+    _size_limit: int
+
+    def __init__(
+        self,
+        dataset: AnnotatedDataset[_TDatum],
+        selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
+    ) -> None:
+        self._dataset = dataset
+        self._size_limit = len(dataset)
+        self._selection = list(range(self._size_limit))
+        self._selections = self._sort_selections(selections)
+        self.__dict__.update(dataset.__dict__)
+
+        # Ensure metadata is populated correctly as DatasetMetadata TypedDict
+        _metadata = getattr(dataset, "metadata", {})
+        if "id" not in _metadata:
+            _metadata["id"] = dataset.__class__.__name__
+        self._metadata = DatasetMetadata(**_metadata)
+
+        if self._selections:
+            self._apply_selections()
+
+    @property
+    def metadata(self) -> DatasetMetadata:
+        return self._metadata
+
+    def __str__(self) -> str:
+        nt = "\n "
+        title = f"{self.__class__.__name__} Dataset"
+        sep = "-" * len(title)
+        selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
+        return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
+
+    def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
+        if not selections:
+            return []
+
+        selections = [selections] if isinstance(selections, Selection) else selections
+        grouped: dict[int, list[Selection]] = {}
+        for selection in selections:
+            grouped.setdefault(selection.stage, []).append(selection)
+        selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
+        return selection_list
+
+    def _apply_selections(self) -> None:
+        for selection in self._selections:
+            selection(self)
+        self._selection = self._selection[: self._size_limit]
+
+    def __getattr__(self, name: str, /) -> Any:
+        selfattr = getattr(self._dataset, name, None)
+        return selfattr if selfattr is not None else getattr(self._dataset, name)
+
+    def __getitem__(self, index: int) -> _TDatum:
+        return self._dataset[self._selection[index]]
+
+    def __iter__(self) -> Iterator[_TDatum]:
+        for i in range(len(self)):
+            yield self[i]
+
+    def __len__(self) -> int:
+        return len(self._selection)
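
Select drives each Selection by calling it with the wrapper itself and then truncating the working index list to _size_limit, so selections operate on that private state. As a rough illustration of that contract, here is a hypothetical custom selection; EvenIndices is not part of the release and simply mirrors how the built-in selections appear to hook into the FILTER stage.

# Hypothetical sketch (not part of the release): a custom Selection that keeps
# only even-numbered indices of the wrapped dataset.
from dataeval.utils.data._selection import Select, Selection, SelectionStage


class EvenIndices(Selection):
    stage = SelectionStage.FILTER  # applied after STATE selections, before ORDER

    def __call__(self, dataset: Select) -> None:
        # Selections mutate the wrapper's working index list in place;
        # _apply_selections later truncates it to dataset._size_limit.
        dataset._selection = [i for i in dataset._selection if i % 2 == 0]


# Usage, assuming `dataset` is any AnnotatedDataset:
#     selected = Select(dataset, selections=[EvenIndices()])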
dataeval/utils/{dataset/split.py → data/_split.py}
@@ -3,8 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from dataclasses import dataclass
-from typing import Any, Iterator, NamedTuple, Protocol
+from typing import Any, Iterator, Protocol
 
 import numpy as np
 from numpy.typing import NDArray
@@ -13,31 +12,8 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
-from dataeval.output import Output, set_metadata
-
-
-class TrainValSplit(NamedTuple):
-    """Tuple containing train and validation indices"""
-
-    train: NDArray[np.intp]
-    val: NDArray[np.intp]
-
-
-@dataclass(frozen=True)
-class SplitDatasetOutput(Output):
-    """
-    Output class containing test indices and a list of TrainValSplits.
-
-    Attributes
-    ----------
-    test: NDArray[np.intp]
-        Indices for the test set
-    folds: list[TrainValSplit]
-        List where each index contains the indices for the train and validation splits
-    """
-
-    test: NDArray[np.intp]
-    folds: list[TrainValSplit]
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
 
 class KFoldSplitter(Protocol):
@@ -274,8 +250,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
     for name, feature in features2group.items():
         if len(feature) != num_samples:
             raise ValueError(
-                f"Feature length does not match number of labels. "
-                f"Got {len(feature)} features and {num_samples} samples"
+                f"Feature length does not match number of labels. Got {len(feature)} features and {num_samples} samples"
            )
 
        if type_of_target(feature) == "continuous":
@@ -505,23 +480,22 @@ def split_dataset(
     if is_groupable(possible_groups, group_partitions):
         groups = possible_groups
 
-    test_indices: NDArray[np.intp]
     index = np.arange(label_length)
 
-    tv_indices, test_indices = (
+    tvs = (
         single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
         if test_frac
-        else (index, np.array([], dtype=np.intp))
+        else TrainValSplit(index, np.array([], dtype=np.intp))
     )
 
-    tv_labels = labels[tv_indices]
-    tv_groups = groups[tv_indices] if groups is not None else None
+    tv_labels = labels[tvs.train]
+    tv_groups = groups[tvs.train] if groups is not None else None
 
     if num_folds == 1:
-        tv_splits = [single_split(tv_indices, tv_labels, val_frac, tv_groups, stratify)]
+        tv_splits = [single_split(tvs.train, tv_labels, val_frac, tv_groups, stratify)]
     else:
-        tv_splits = make_splits(tv_indices, tv_labels, num_folds, tv_groups, stratify)
+        tv_splits = make_splits(tvs.train, tv_labels, num_folds, tv_groups, stratify)
 
-    folds: list[TrainValSplit] = [TrainValSplit(tv_indices[split.train], tv_indices[split.val]) for split in tv_splits]
+    folds: list[TrainValSplit] = [TrainValSplit(tvs.train[split.train], tvs.train[split.val]) for split in tv_splits]
 
-    return SplitDatasetOutput(test_indices, folds)
+    return SplitDatasetOutput(tvs.val, folds)
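
The refactor does not change the shape of the returned object: SplitDatasetOutput still pairs a test index array with a list of TrainValSplit(train, val) folds, now imported from dataeval.outputs._utils. A hedged consumption sketch follows; the index values are invented, and in practice the object comes back from split_dataset rather than being built by hand.

# Hypothetical sketch (not part of the release): the field layout consumers rely on.
import numpy as np

from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit

splits = SplitDatasetOutput(
    np.array([8, 9], dtype=np.intp),                          # test: held-out indices
    [TrainValSplit(np.array([0, 1, 2, 3], dtype=np.intp),     # one fold's train indices
                   np.array([4, 5, 6, 7], dtype=np.intp))],   # and its validation indices
)

held_out = splits.test
for fold in splits.folds:            # one TrainValSplit per requested fold
    train_idx, val_idx = fold.train, fold.val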
dataeval/utils/data/_targets.py (new file)
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Iterator
+
+__all__ = []
+
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def _len(arr: NDArray, dim: int) -> int:
+    return 0 if len(arr) == 0 else len(np.atleast_1d(arr) if dim == 1 else np.atleast_2d(arr))
+
+
+@dataclass(frozen=True)
+class Targets:
+    """
+    Dataclass defining targets for image classification or object detection.
+
+    Attributes
+    ----------
+    labels : NDArray[np.intp]
+        Labels (N,) for N images or objects
+    scores : NDArray[np.float32]
+        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
+    bboxes : NDArray[np.float32] | None
+        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
+    source : NDArray[np.intp] | None
+        Source image index (N,) for N objects
+    """
+
+    labels: NDArray[np.intp]
+    scores: NDArray[np.float32]
+    bboxes: NDArray[np.float32] | None
+    source: NDArray[np.intp] | None
+
+    def __post_init__(self) -> None:
+        if (self.bboxes is None) != (self.source is None):
+            raise ValueError("Either both bboxes and source must be provided or neither.")
+
+        labels = _len(self.labels, 1)
+        scores = _len(self.scores, 2) if self.bboxes is None else _len(self.scores, 1)
+        bboxes = labels if self.bboxes is None else _len(self.bboxes, 2)
+        source = labels if self.source is None else _len(self.source, 1)
+
+        if labels != scores or labels != bboxes or labels != source:
+            raise ValueError(
+                "Labels, scores, bboxes and source must be the same length (if provided).\n"
+                f" labels: {self.labels.shape}\n"
+                f" scores: {self.scores.shape}\n"
+                f" bboxes: {None if self.bboxes is None else self.bboxes.shape}\n"
+                f" source: {None if self.source is None else self.source.shape}\n"
+            )
+
+        if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
+            raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
+
+    def __len__(self) -> int:
+        if self.source is None:
+            return len(self.labels)
+        else:
+            return len(np.unique(self.source))
+
+    def __getitem__(self, idx: int, /) -> Targets:
+        if self.source is None or self.bboxes is None:
+            return Targets(
+                np.atleast_1d(self.labels[idx]),
+                np.atleast_2d(self.scores[idx]),
+                None,
+                None,
+            )
+        else:
+            mask = np.where(self.source == idx, True, False)
+            return Targets(
+                np.atleast_1d(self.labels[mask]),
+                np.atleast_1d(self.scores[mask]),
+                np.atleast_2d(self.bboxes[mask]),
+                np.atleast_1d(self.source[mask]),
+            )
+
+    def __iter__(self) -> Iterator[Targets]:
+        for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
+            yield self[i]
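
To make the shape contract concrete, here is a small hypothetical construction of object-detection Targets: three boxes spread over two images, with source mapping each object back to its image. Indexing and iteration both regroup objects per source image; the values below are invented for illustration.

# Hypothetical sketch (not part of the release): three detected objects across two images.
import numpy as np

from dataeval.utils.data._targets import Targets

targets = Targets(
    labels=np.array([1, 0, 2], dtype=np.intp),
    scores=np.array([0.9, 0.75, 0.6], dtype=np.float32),    # per-object confidences
    bboxes=np.array([[0, 0, 10, 10], [5, 5, 20, 20], [2, 2, 8, 8]], dtype=np.float32),
    source=np.array([0, 0, 1], dtype=np.intp),               # object -> image index
)

print(len(targets))      # 2: number of unique source images, not objects
image0 = targets[0]      # Targets holding only the two objects from image 0
print(image0.labels)     # [1 0]
for per_image in targets:            # iterates one Targets per source image
    print(per_image.bboxes.shape)    # (2, 4) then (1, 4)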