dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. dataeval/{log.py → _log.py} +0 -0
  94. dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/datasets/_voc.py
@@ -0,0 +1,352 @@
+from __future__ import annotations
+
+__all__ = []
+
+from pathlib import Path
+from typing import Any, Literal, Sequence, TypeVar
+
+import torch
+from defusedxml.ElementTree import parse
+from numpy.typing import NDArray
+
+from dataeval.utils.data._types import ObjectDetectionTarget, SegmentationTarget, Transform
+from dataeval.utils.data.datasets._base import (
+    BaseDataset,
+    BaseODDataset,
+    BaseSegDataset,
+    DataLocation,
+)
+from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
+
+_TArray = TypeVar("_TArray")
+_TTarget = TypeVar("_TTarget")
+
+VOCClassStringMap = Literal[
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
+
+
+class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
+    _resources = [
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+            filename="VOCtrainval_11-May-2012.tar",
+            md5=True,
+            checksum="6cd6e144f989b92b3379bac3b3de84fd",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+            filename="VOCtrainval_25-May-2011.tar",
+            md5=True,
+            checksum="6c3384ef61512963050cb5d687e5bf1e",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+            filename="VOCtrainval_03-May-2010.tar",
+            md5=True,
+            checksum="da459979d0c395079b5c75ee67908abb",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+            filename="VOCtrainval_11-May-2009.tar",
+            md5=True,
+            checksum="da459979d0c395079b5c75ee67908abb",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+            filename="VOCtrainval_14-Jul-2008.tar",
+            md5=True,
+            checksum="2629fa636546599198acfcfbfcf1904a",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+            filename="VOCtrainval_06-Nov-2007.tar",
+            md5=True,
+            checksum="c52e279531787c972589f7e41ab4ae64",
+        ),
+        DataLocation(
+            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+            filename="VOCtest_06-Nov-2007.tar",
+            md5=True,
+            checksum="b6e924de25625d8de591ea690078ad9f",
+        ),
+    ]
+
+    index2label: dict[int, str] = {
+        0: "aeroplane",
+        1: "bicycle",
+        2: "bird",
+        3: "boat",
+        4: "bottle",
+        5: "bus",
+        6: "car",
+        7: "cat",
+        8: "chair",
+        9: "cow",
+        10: "diningtable",
+        11: "dog",
+        12: "horse",
+        13: "motorbike",
+        14: "person",
+        15: "pottedplant",
+        16: "sheep",
+        17: "sofa",
+        18: "train",
+        19: "tvmonitor",
+    }
+
+    def __init__(
+        self,
+        root: str | Path,
+        year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
+        image_set: Literal["train", "val", "test", "base"] = "train",
+        download: bool = False,
+        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+        verbose: bool = False,
+    ) -> None:
+        self.year = year
+        self._resource_index = self._get_year_image_set_index(year, image_set)
+        super().__init__(
+            root,
+            download,
+            image_set,
+            transforms,
+            verbose,
+        )
+
+    def _get_dataset_dir(self) -> Path:
+        """Function to reassign the dataset directory for common use with the VOC detection and segmentation classes"""
+        if self._root.stem == f"VOC{self.year}":
+            dataset_dir: Path = self._root
+        else:
+            dataset_dir: Path = self._root / f"VOC{self.year}"
+        if not dataset_dir.exists():
+            dataset_dir.mkdir(parents=True, exist_ok=True)
+        return dataset_dir
+
+    def _get_year_image_set_index(self, year, image_set) -> int:
+        """Function to ensure that the correct resource file is accessed"""
+        if year == "2007" and image_set == "test":
+            return -1
+        elif year != "2007" and image_set == "test":
+            raise ValueError(
+                f"The only test set available is for the year 2007, not {year}. "
+                "Either select the year 2007 or use a different image_set."
+            )
+        else:
+            return 2012 - int(year)
+
+    def _get_image_sets(self) -> dict[str, list[str]]:
+        """Function to create the list of images in each image set"""
+        image_folder = self.path / "JPEGImages"
+        image_set_list = ["train", "val", "trainval"] if self.image_set != "test" else ["test"]
+        image_sets = {}
+        for image_set in image_set_list:
+            text_file = self.path / "ImageSets" / "Main" / (image_set + ".txt")
+            selected_images: list[str] = []
+            with open(text_file) as f:
+                for line in f.readlines():
+                    out = line.strip()
+                    selected_images.append(str(image_folder / (out + ".jpg")))
+
+            name = "base" if image_set == "trainval" else image_set
+            image_sets[name] = selected_images
+        return image_sets
+
+    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
+        """Function to load in the file paths for the data, annotations and segmentation masks"""
+        file_meta = {"year": [], "image_id": [], "mask_path": []}
+        ann_folder = self.path / "Annotations"
+        seg_folder = self.path / "SegmentationClass"
+
+        # Load in the image sets
+        image_sets = self._get_image_sets()
+
+        # Get the data, annotations and metadata
+        annotations = []
+        data = image_sets[self.image_set]
+        for entry in data:
+            file_name = Path(entry).name
+            file_stem = Path(entry).stem
+            # Remove file extension and split by "_"
+            parts = file_stem.split("_")
+            file_meta["year"].append(parts[0])
+            file_meta["image_id"].append(parts[1])
+            file_meta["mask_path"].append(str(seg_folder / file_name))
+            annotations.append(str(ann_folder / file_stem) + ".xml")
+
+        return data, annotations, file_meta
+
+    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
+        boxes: list[list[float]] = []
+        label_str = []
+        root = parse(annotation).getroot()
+        num_objects = len(root.findall("object"))
+        additional_meta: dict[str, Any] = {
+            "folder": [root.findtext("folder", default="") for _ in range(num_objects)],
+            "filename": [root.findtext("filename", default="") for _ in range(num_objects)],
+            "database": [root.findtext("source/database", default="") for _ in range(num_objects)],
+            "annotation_source": [root.findtext("source/annotation", default="") for _ in range(num_objects)],
+            "image_source": [root.findtext("source/image", default="") for _ in range(num_objects)],
+            "image_width": [int(root.findtext("size/width", default="-1")) for _ in range(num_objects)],
+            "image_height": [int(root.findtext("size/height", default="-1")) for _ in range(num_objects)],
+            "image_depth": [int(root.findtext("size/depth", default="-1")) for _ in range(num_objects)],
+            "segmented": [int(root.findtext("segmented", default="-1")) for _ in range(num_objects)],
+            "pose": [],
+            "truncated": [],
+            "difficult": [],
+        }
+        for obj in root.findall("object"):
+            label_str.append(obj.findtext("name", default=""))
+            additional_meta["pose"].append(obj.findtext("pose", default=""))
+            additional_meta["truncated"].append(int(obj.findtext("truncated", default="-1")))
+            additional_meta["difficult"].append(int(obj.findtext("difficult", default="-1")))
+            boxes.append(
+                [
+                    float(obj.findtext("bndbox/xmin", default="0")),
+                    float(obj.findtext("bndbox/ymin", default="0")),
+                    float(obj.findtext("bndbox/xmax", default="0")),
+                    float(obj.findtext("bndbox/ymax", default="0")),
+                ]
+            )
+        labels = [self.label2index[label] for label in label_str]
+        return boxes, labels, additional_meta
+
+
+class VOCDetection(
+    BaseVOCDataset[NDArray[Any], ObjectDetectionTarget[NDArray[Any]]],
+    BaseODDataset[NDArray[Any]],
+    BaseDatasetNumpyMixin,
+):
+    """
+    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Parameters
+    ----------
+    root : str or pathlib.Path
+        Root directory of dataset where the ``vocdataset`` folder exists.
+    download : bool, default False
+        If True, downloads the dataset from the internet and puts it in root directory.
+        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+    image_set : "train", "val", "test", or "base", default "train"
+        If "test", then dataset year must be "2007".
+        If "base", then the combined dataset of "train" and "val" is returned.
+    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
+        The dataset year.
+    transforms : Transform | Sequence[Transform] | None, default None
+        Transform(s) to apply to the data.
+    verbose : bool, default False
+        If True, outputs print statements.
+
+    Attributes
+    ----------
+    index2label : dict
+        Dictionary which translates from class integers to the associated class strings.
+    label2index : dict
+        Dictionary which translates from class strings to the associated class integers.
+    path : Path
+        Location of the folder containing the data.
+    metadata : dict
+        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+    """
+
+
+class VOCDetectionTorch(
+    BaseVOCDataset[torch.Tensor, ObjectDetectionTarget[torch.Tensor]],
+    BaseODDataset[torch.Tensor],
+    BaseDatasetTorchMixin,
+):
+    """
+    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Parameters
+    ----------
+    root : str or pathlib.Path
+        Root directory of dataset where the ``vocdataset`` folder exists.
+    download : bool, default False
+        If True, downloads the dataset from the internet and puts it in root directory.
+        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+    image_set : "train", "val", "test", or "base", default "train"
+        If "test", then dataset year must be "2007".
+        If "base", then the combined dataset of "train" and "val" is returned.
+    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
+        The dataset year.
+    transforms : Transform | Sequence[Transform] | None, default None
+        Transform(s) to apply to the data.
+    verbose : bool, default False
+        If True, outputs print statements.
+
+    Attributes
+    ----------
+    index2label : dict
+        Dictionary which translates from class integers to the associated class strings.
+    label2index : dict
+        Dictionary which translates from class strings to the associated class integers.
+    path : Path
+        Location of the folder containing the data.
+    metadata : dict
+        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+    """
+
+
+class VOCSegmentation(
+    BaseVOCDataset[NDArray[Any], SegmentationTarget[NDArray[Any]]],
+    BaseSegDataset[NDArray[Any]],
+    BaseDatasetNumpyMixin,
+):
+    """
+    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+    Parameters
+    ----------
+    root : str or pathlib.Path
+        Root directory of dataset where the ``vocdataset`` folder exists.
+    download : bool, default False
+        If True, downloads the dataset from the internet and puts it in root directory.
+        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+    image_set : "train", "val", "test", or "base", default "train"
+        If "test", then dataset year must be "2007".
+        If "base", then the combined dataset of "train" and "val" is returned.
+    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
+        The dataset year.
+    transforms : Transform | Sequence[Transform] | None, default None
+        Transform(s) to apply to the data.
+    verbose : bool, default False
+        If True, outputs print statements.
+
+    Attributes
+    ----------
+    index2label : dict
+        Dictionary which translates from class integers to the associated class strings.
+    label2index : dict
+        Dictionary which translates from class strings to the associated class integers.
+    path : Path
+        Location of the folder containing the data.
+    metadata : dict
+        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+    """
+
+    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
+        filepaths, targets, datum_metadata = super()._load_data()
+        self._masks = datum_metadata.pop("mask_path")
+        return filepaths, targets, datum_metadata
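The three public classes above share the constructor documented in their docstrings. A minimal usage sketch follows; it is not part of the diff, and it assumes both that the package re-exports VOCDetection from dataeval.utils.data.datasets (the package __init__.py listed as file 63 above is not shown here) and that indexing a detection dataset yields an (image, target, metadata) tuple, as the BaseODDataset/mixin design suggests.

    # Hypothetical usage sketch -- not from the diff. The import path and the
    # (image, target, metadata) datum layout are assumptions, not confirmed here.
    from dataeval.utils.data.datasets import VOCDetection

    # First use downloads and unpacks VOCtrainval_11-May-2012.tar under ./data.
    dataset = VOCDetection("./data", year="2012", image_set="train", download=True)

    image, target, datum_meta = dataset[0]   # assumed datum layout
    print(dataset.index2label)               # {0: "aeroplane", ..., 19: "tvmonitor"}
    print(datum_meta["pose"], datum_meta["truncated"])  # per-object annotation metadata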
dataeval/utils/data/selections/__init__.py
@@ -0,0 +1,15 @@
+"""Provides selection classes for selecting subsets of Computer Vision datasets."""
+
+__all__ = [
+    "ClassFilter",
+    "Indices",
+    "Limit",
+    "Reverse",
+    "Shuffle",
+]
+
+from dataeval.utils.data.selections._classfilter import ClassFilter
+from dataeval.utils.data.selections._indices import Indices
+from dataeval.utils.data.selections._limit import Limit
+from dataeval.utils.data.selections._reverse import Reverse
+from dataeval.utils.data.selections._shuffle import Shuffle
dataeval/utils/data/selections/_classfilter.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Sequence, TypeVar
+
+import numpy as np
+
+from dataeval.typing import Array
+from dataeval.utils._array import as_numpy
+from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+_TData = TypeVar("_TData")
+_TTarget = TypeVar("_TTarget", bound=Array)
+
+
+class ClassFilter(Selection[_TData, _TTarget]):
+    """
+    Filter and balance the dataset by class.
+
+    Parameters
+    ----------
+    classes : Sequence[int] or None, default None
+        The classes to filter by. If None, all classes are included.
+    balance : bool, default False
+        Whether to balance the classes.
+
+    Note
+    ----
+    If `balance` is True, the total number of instances of each class will
+    be equalized. This may result in a lower total number of instances.
+    """
+
+    stage = SelectionStage.FILTER
+
+    def __init__(self, classes: Sequence[int] | None = None, balance: bool = False) -> None:
+        self.classes = classes
+        self.balance = balance
+
+    def __call__(self, dataset: Select[_TData, _TTarget]) -> None:
+        if self.classes is None and not self.balance:
+            return
+
+        per_class_limit = dataset._size_limit // len(self.classes) if self.classes and self.balance else 0
+        class_indices: dict[int, list[int]] = {} if self.classes is None else {k: [] for k in self.classes}
+        for i, idx in enumerate(dataset._selection):
+            target = dataset._dataset[idx][1]
+            if isinstance(target, Array):
+                label = int(np.argmax(as_numpy(target)))
+            else:
+                # ObjectDetectionTarget and SegmentationTarget not supported yet
+                raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
+            if not self.classes or label in self.classes:
+                class_indices.setdefault(label, []).append(i)
+            if per_class_limit and all(len(indices) >= per_class_limit for indices in class_indices.values()):
+                break
+
+        per_class_limit = min(len(c) for c in class_indices.values()) if self.balance else dataset._size_limit
+        subselection = sorted([i for v in class_indices.values() for i in v[:per_class_limit]])
+        dataset._selection = [dataset._selection[i] for i in subselection]
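The balancing pass above collects dataset positions per class, stops early once every requested class reaches its quota, then trims each class to the size of the smallest one before re-sorting. A self-contained sketch of that trimming step (plain Python, hypothetical index lists):

    # Illustration of ClassFilter's balancing step: trim every class's index
    # list to the size of the smallest class, then re-sort the survivors.
    class_indices = {0: [0, 2, 5, 9], 1: [1, 4], 2: [3, 6, 8]}

    per_class_limit = min(len(v) for v in class_indices.values())   # 2
    subselection = sorted(i for v in class_indices.values() for i in v[:per_class_limit])
    print(subselection)  # [0, 1, 2, 3, 4, 6] -> two instances of each class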
dataeval/utils/data/selections/_indices.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any, Sequence
+
+from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+
+class Indices(Selection[Any, Any]):
+    """
+    Selects specific indices from the dataset.
+
+    Parameters
+    ----------
+    indices : Sequence[int]
+        The indices to select from the dataset.
+    """
+
+    stage = SelectionStage.FILTER
+
+    def __init__(self, indices: Sequence[int]) -> None:
+        self.indices = indices
+
+    def __call__(self, dataset: Select[Any, Any]) -> None:
+        dataset._selection = [index for index in self.indices if index in dataset._selection]
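One behavior worth noting: the comprehension iterates self.indices, so the resulting selection follows the order of the requested indices (and keeps duplicates), rather than preserving the dataset's current ordering. A plain-Python illustration:

    # Illustration of Indices.__call__: the output follows the order of the
    # requested indices, not the dataset's current selection order.
    selection = [0, 1, 2, 3, 4]   # stands in for dataset._selection
    indices = [4, 2, 2, 7]        # 7 is not in the selection and is dropped
    print([i for i in indices if i in selection])  # [4, 2, 2]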
dataeval/utils/data/selections/_limit.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any
+
+from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+
+class Limit(Selection[Any, Any]):
+    """
+    Limit the size of the dataset.
+
+    Parameters
+    ----------
+    size : int
+        The maximum size of the dataset.
+    """
+
+    stage = SelectionStage.STATE
+
+    def __init__(self, size: int) -> None:
+        self.size = size
+
+    def __call__(self, dataset: Select[Any, Any]) -> None:
+        dataset._size_limit = self.size
dataeval/utils/data/selections/_reverse.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any
+
+from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+
+class Reverse(Selection[Any, Any]):
+    """
+    Reverse the selection order of the dataset.
+    """
+
+    stage = SelectionStage.ORDER
+
+    def __call__(self, dataset: Select[Any, Any]) -> None:
+        dataset._selection.reverse()
dataeval/utils/data/selections/_shuffle.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any
+
+import numpy as np
+
+from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+
+class Shuffle(Selection[Any, Any]):
+    """
+    Shuffle the dataset using a seed.
+
+    Parameters
+    ----------
+    seed
+        Seed for the random number generator.
+    """
+
+    stage = SelectionStage.ORDER
+
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, dataset: Select[Any, Any]) -> None:
+        rng = np.random.default_rng(self.seed)
+        rng.shuffle(dataset._selection)
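Taken together, each selection mutates only two attributes of the wrapping Select object: the _selection index list and the _size_limit. Select itself (dataeval/utils/data/_selection.py, file 58 above) is not shown in this section, so the sketch below uses a minimal stand-in exposing just those attributes; the STATE → FILTER → ORDER application order is an assumption inferred from the SelectionStage names, not confirmed by this diff.

    # Minimal sketch, NOT the real Select: a stand-in with only the two
    # attributes that the __call__ implementations above actually touch.
    from dataeval.utils.data.selections import Limit, Reverse, Shuffle

    class _FakeSelect:
        def __init__(self, size: int) -> None:
            self._selection = list(range(size))
            self._size_limit = size

    ds = _FakeSelect(10)
    for selection in (Limit(5), Shuffle(seed=0), Reverse()):  # assumed STATE -> FILTER -> ORDER
        selection(ds)

    print(ds._size_limit)  # 5 -- only consumed by filters such as ClassFilter
    print(ds._selection)   # all ten indices, shuffled with seed 0 then reversed

Note that Limit alone does not truncate _selection; it records a budget that filter-stage selections like ClassFilter read via dataset._size_limit.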