dataeval 0.84.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two package versions.
Files changed (40)
  1. dataeval/__init__.py +1 -1
  2. dataeval/detectors/drift/__init__.py +2 -2
  3. dataeval/detectors/drift/_base.py +55 -203
  4. dataeval/detectors/drift/_cvm.py +19 -30
  5. dataeval/detectors/drift/_ks.py +18 -30
  6. dataeval/detectors/drift/_mmd.py +189 -53
  7. dataeval/detectors/drift/_uncertainty.py +52 -56
  8. dataeval/detectors/drift/updates.py +13 -12
  9. dataeval/detectors/linters/duplicates.py +5 -3
  10. dataeval/detectors/linters/outliers.py +2 -2
  11. dataeval/detectors/ood/ae.py +1 -1
  12. dataeval/metrics/stats/_base.py +7 -7
  13. dataeval/metrics/stats/_dimensionstats.py +2 -2
  14. dataeval/metrics/stats/_hashstats.py +2 -2
  15. dataeval/metrics/stats/_imagestats.py +4 -4
  16. dataeval/metrics/stats/_pixelstats.py +2 -2
  17. dataeval/metrics/stats/_visualstats.py +2 -2
  18. dataeval/typing.py +22 -19
  19. dataeval/utils/_array.py +18 -7
  20. dataeval/utils/data/_dataset.py +6 -4
  21. dataeval/utils/data/_embeddings.py +46 -7
  22. dataeval/utils/data/_images.py +2 -2
  23. dataeval/utils/data/_metadata.py +5 -4
  24. dataeval/utils/data/datasets/_base.py +7 -4
  25. dataeval/utils/data/datasets/_cifar10.py +9 -9
  26. dataeval/utils/data/datasets/_milco.py +42 -14
  27. dataeval/utils/data/datasets/_mnist.py +9 -5
  28. dataeval/utils/data/datasets/_ships.py +8 -4
  29. dataeval/utils/data/datasets/_voc.py +40 -19
  30. dataeval/utils/data/selections/__init__.py +2 -0
  31. dataeval/utils/data/selections/_classbalance.py +38 -0
  32. dataeval/utils/data/selections/_classfilter.py +14 -29
  33. dataeval/utils/data/selections/_prioritize.py +1 -1
  34. dataeval/utils/data/selections/_shuffle.py +2 -2
  35. dataeval/utils/torch/_internal.py +12 -35
  36. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  37. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +39 -39
  38. dataeval/detectors/drift/_torch.py +0 -222
  39. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  40. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/utils/data/datasets/_milco.py

@@ -3,7 +3,7 @@ from __future__ import annotations
  __all__ = []

  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Sequence
+ from typing import TYPE_CHECKING, Any, Literal, Sequence

  from numpy.typing import NDArray

@@ -16,21 +16,20 @@ if TYPE_CHECKING:

  class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  """
- A side-scan sonar dataset focused on mine (object) detection.
+ A side-scan sonar dataset focused on mine-like object detection.

  The dataset comes from the paper
  `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
  by N.P. Santos et. al. (2024).

- This class only accesses a portion of the above dataset due to size constraints.
  The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
  dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
  All the images were carefully analyzed and annotated, including the image coordinates of the
  Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
  and MIne-Like COntacts (MILCO) classes.

- This dataset is consists of 261 images (120 images from 2015, 93 images from 2017, and 48 images from 2021).
- In these 261 images, there are 315 MILCO objects, and 175 NOMBO objects.
+ This dataset is consists of 345 images from 2010, 120 images from 2015, 93 images from 2017, 564 images from 2018,
+ and 48 images from 2021). In these 1170 images, there are 432 MILCO objects, and 235 NOMBO objects.
  The class “0” corresponds to a MILCO object and the class “1” corresponds to a NOMBO object.
  The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
  given as percentages of the image (x_BB = x/img_width, y_BB = y/img_height, etc.).
@@ -40,11 +39,17 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``milco`` folder exists.
+ image_set: "train", "operational", or "base", default "train"
+ If "train", then the images from 2015, 2017 and 2021 are selected,
+ resulting in 315 MILCO objects and 177 NOMBO objects.
+ If "operational", then the images from 2010 and 2018 are selected,
+ resulting in 117 MILCO objects and 58 NOMBO objects.
+ If "base", then the full dataset is selected.
+ transforms : Transform, Sequence[Transform] or None, default None
+ Transform(s) to apply to the data.
  download : bool, default False
  If True, downloads the dataset from the internet and puts it in root directory.
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
- transforms : Transform, Sequence[Transform] or None, default None
- Transform(s) to apply to the data.
  verbose : bool, default False
  If True, outputs print statements.

@@ -52,8 +57,8 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  ----------
  path : pathlib.Path
  Location of the folder containing the data.
- image_set : "base"
- The base image set is the only available image set for the MILCO dataset.
+ image_set : "train", "operational" or "base"
+ The selected image set from the dataset.
  index2label : dict[int, str]
  Dictionary which translates from class integers to the associated class strings.
  label2index : dict[str, int]
@@ -64,6 +69,10 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_
  """

  _resources = [
@@ -85,6 +94,18 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  md5=True,
  checksum="b84749b21fa95a4a4c7de3741db78bc7",
  ),
+ DataLocation(
+ url="https://figshare.com/ndownloader/files/43169008",
+ filename="2010.zip",
+ md5=True,
+ checksum="43347a0cc383c0d3dbe0d24ae56f328d",
+ ),
+ DataLocation(
+ url="https://figshare.com/ndownloader/files/43169011",
+ filename="2018.zip",
+ md5=True,
+ checksum="25d091044a10c78674fedad655023e3b",
+ ),
  ]

  index2label: dict[int, str] = {
@@ -95,15 +116,16 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  def __init__(
  self,
  root: str | Path,
- download: bool = False,
+ image_set: Literal["train", "operational", "base"] = "train",
  transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+ download: bool = False,
  verbose: bool = False,
  ) -> None:
  super().__init__(
  root,
- download,
- "base",
+ image_set,
  transforms,
+ download,
  verbose,
  )

@@ -112,10 +134,16 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  targets: list[str] = []
  datum_metadata: dict[str, list[Any]] = {}
  metadata_list: list[dict[str, Any]] = []
+ image_sets: dict[str, list[int]] = {
+ "base": list(range(len(self._resources))),
+ "train": list(range(3)),
+ "operational": list(range(3, len(self._resources))),
+ }

  # Load the data
- for resource in self._resources:
- self._resource = resource
+ resource_indices = image_sets[self.image_set]
+ for idx in resource_indices:
+ self._resource = self._resources[idx]
  filepath, target, metadata = super()._load_data()
  filepaths.extend(filepath)
  targets.extend(target)
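The MILCO changes above add an `image_set` split ("train", "operational" or "base") and move `download` after `transforms` in the constructor. A minimal usage sketch, assuming the class is importable from `dataeval.utils.data.datasets` and that `./data` is writable:

    from dataeval.utils.data.datasets import MILCO

    # "train" pulls the 2015/2017/2021 resources; "operational" pulls 2010/2018.
    train_ds = MILCO("./data", image_set="train", download=True)
    operational_ds = MILCO("./data", image_set="operational", download=True)
    print(train_ds.size, operational_ds.size)

Passing `download` by keyword keeps the call valid under both the 0.84.0 and 0.84.1 argument orders.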
dataeval/utils/data/datasets/_mnist.py

@@ -49,9 +49,6 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``mnist`` folder exists.
- download : bool, default False
- If True, downloads the dataset from the internet and puts it in root directory.
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  image_set : "train", "test" or "base", default "train"
  If "base", returns all of the data to allow the user to create their own splits.
  corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
@@ -60,6 +57,9 @@
  Corruption to apply to the data.
  transforms : Transform, Sequence[Transform] or None, default None
  Transform(s) to apply to the data.
+ download : bool, default False
+ If True, downloads the dataset from the internet and puts it in root directory.
+ Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  verbose : bool, default False
  If True, outputs print statements.

@@ -81,6 +81,10 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_ for corruption dataset
  """

  _resources = [
@@ -114,10 +118,10 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  def __init__(
  self,
  root: str | Path,
- download: bool = False,
  image_set: Literal["train", "test", "base"] = "train",
  corruption: CorruptionStringMap | None = None,
  transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+ download: bool = False,
  verbose: bool = False,
  ) -> None:
  self.corruption = corruption
@@ -127,9 +131,9 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):

  super().__init__(
  root,
- download,
  image_set,
  transforms,
+ download,
  verbose,
  )

dataeval/utils/data/datasets/_ships.py

@@ -31,11 +31,11 @@ class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``shipdataset`` folder exists.
+ transforms : Transform, Sequence[Transform] or None, default None
+ Transform(s) to apply to the data.
  download : bool, default False
  If True, downloads the dataset from the internet and puts it in root directory.
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
- transforms : Transform, Sequence[Transform] or None, default None
- Transform(s) to apply to the data.
  verbose : bool, default False
  If True, outputs print statements.

@@ -55,6 +55,10 @@ class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `CC BY-SA 4.0 <https://creativecommons.org/licenses/by-sa/4.0/>`_
  """

  _resources = [
@@ -74,15 +78,15 @@ class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
  def __init__(
  self,
  root: str | Path,
- download: bool = False,
  transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+ download: bool = False,
  verbose: bool = False,
  ) -> None:
  super().__init__(
  root,
- download,
  "base",
  transforms,
+ download,
  verbose,
  )
  self._scenes: list[str] = self._load_scenes()
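As with MILCO, the MNIST and Ships constructors above (and the VOC classes below) move `download` to the position after `transforms`, so positional calls written against 0.84.0 now bind the wrong parameters. A hedged sketch, assuming `dataeval.utils.data.datasets` is the public import path:

    from dataeval.utils.data.datasets import MNIST

    # 0.84.0 order: MNIST(root, download, image_set, corruption, transforms, verbose)
    # 0.84.1 order: MNIST(root, image_set, corruption, transforms, download, verbose)
    # Keyword arguments are valid under both orders:
    mnist = MNIST("./data", image_set="train", corruption="identity", download=True)
    print(mnist.size)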
dataeval/utils/data/datasets/_voc.py

@@ -14,6 +14,8 @@ from dataeval.utils.data.datasets._base import (
  BaseODDataset,
  BaseSegDataset,
  DataLocation,
+ _TArray,
+ _TTarget,
  )
  from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
  from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget
@@ -21,9 +23,6 @@ from dataeval.utils.data.datasets._types import ObjectDetectionTarget, Segmentat
  if TYPE_CHECKING:
  from dataeval.typing import Transform

- _TArray = TypeVar("_TArray")
- _TTarget = TypeVar("_TTarget")
-
  VOCClassStringMap = Literal[
  "aeroplane",
  "bicycle",
@@ -121,19 +120,19 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
  def __init__(
  self,
  root: str | Path,
- year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
  image_set: Literal["train", "val", "test", "base"] = "train",
- download: bool = False,
+ year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
  transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+ download: bool = False,
  verbose: bool = False,
  ) -> None:
  self.year = year
  self._resource_index = self._get_year_image_set_index(year, image_set)
  super().__init__(
  root,
- download,
  image_set,
  transforms,
+ download,
  verbose,
  )

@@ -191,10 +190,14 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
  for entry in data:
  file_name = Path(entry).name
  file_stem = Path(entry).stem
- # Remove file extension and split by "_"
- parts = file_stem.split("_")
- file_meta["year"].append(parts[0])
- file_meta["image_id"].append(parts[1])
+ if self.year != "2007":
+ # Remove file extension and split by "_"
+ parts = file_stem.split("_")
+ file_meta["year"].append(parts[0])
+ file_meta["image_id"].append(parts[1])
+ else:
+ file_meta["year"].append(self.year)
+ file_meta["image_id"].append(file_stem)
  file_meta["mask_path"].append(str(seg_folder / file_name))
  annotations.append(str(ann_folder / file_stem) + ".xml")

@@ -250,9 +253,6 @@ class VOCDetection(
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``vocdataset`` folder exists.
- download : bool, default False
- If True, downloads the dataset from the internet and puts it in root directory.
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  image_set : "train", "val", "test", or "base", default "train"
  If "test", then dataset year must be "2007".
  If "base", then the combined dataset of "train" and "val" is returned.
@@ -260,6 +260,9 @@
  The dataset year.
  transforms : Transform, Sequence[Transform] or None, default None
  Transform(s) to apply to the data.
+ download : bool, default False
+ If True, downloads the dataset from the internet and puts it in root directory.
+ Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  verbose : bool, default False
  If True, outputs print statements.

@@ -267,6 +270,8 @@
  ----------
  path : pathlib.Path
  Location of the folder containing the data.
+ year : "2007", "2008", "2009", "2010", "2011" or "2012"
+ The selected dataset year.
  image_set : "train", "val", "test" or "base"
  The selected image set from the dataset.
  index2label : dict[int, str]
@@ -279,6 +284,10 @@
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `Flickr Terms of Use <http://www.flickr.com/terms.gne?legacy=1>`_
  """


@@ -294,9 +303,6 @@ class VOCDetectionTorch(
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``vocdataset`` folder exists.
- download : bool, default False
- If True, downloads the dataset from the internet and puts it in root directory.
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  image_set : "train", "val", "test", or "base", default "train"
  If "test", then dataset year must be "2007".
  If "base", then the combined dataset of "train" and "val" is returned.
@@ -304,6 +310,9 @@ class VOCDetectionTorch(
  The dataset year.
  transforms : Transform, Sequence[Transform] or None, default None
  Transform(s) to apply to the data.
+ download : bool, default False
+ If True, downloads the dataset from the internet and puts it in root directory.
+ Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  verbose : bool, default False
  If True, outputs print statements.

@@ -311,6 +320,8 @@
  ----------
  path : pathlib.Path
  Location of the folder containing the data.
+ year : "2007", "2008", "2009", "2010", "2011" or "2012"
+ The selected dataset year.
  image_set : "train", "val", "test" or "base"
  The selected image set from the dataset.
  index2label : dict[int, str]
@@ -323,6 +334,10 @@
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `Flickr Terms of Use <http://www.flickr.com/terms.gne?legacy=1>`_
  """


@@ -338,9 +353,6 @@ class VOCSegmentation(
  ----------
  root : str or pathlib.Path
  Root directory of dataset where the ``vocdataset`` folder exists.
- download : bool, default False
- If True, downloads the dataset from the internet and puts it in root directory.
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  image_set : "train", "val", "test", or "base", default "train"
  If "test", then dataset year must be "2007".
  If "base", then the combined dataset of "train" and "val" is returned.
@@ -348,6 +360,9 @@
  The dataset year.
  transforms : Transform, Sequence[Transform] or None, default None
  Transform(s) to apply to the data.
+ download : bool, default False
+ If True, downloads the dataset from the internet and puts it in root directory.
+ Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
  verbose : bool, default False
  If True, outputs print statements.

@@ -355,6 +370,8 @@
  ----------
  path : pathlib.Path
  Location of the folder containing the data.
+ year : "2007", "2008", "2009", "2010", "2011" or "2012"
+ The selected dataset year.
  image_set : "train", "val", "test" or "base"
  The selected image set from the dataset.
  index2label : dict[int, str]
@@ -367,6 +384,10 @@
  The transforms to be applied to the data.
  size : int
  The size of the dataset.
+
+ Note
+ ----
+ Data License: `Flickr Terms of Use <http://www.flickr.com/terms.gne?legacy=1>`_
  """

  def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
dataeval/utils/data/selections/__init__.py

@@ -1,6 +1,7 @@
  """Provides selection classes for selecting subsets of Computer Vision datasets."""

  __all__ = [
+ "ClassBalance",
  "ClassFilter",
  "Indices",
  "Limit",
@@ -9,6 +10,7 @@ __all__ = [
  "Shuffle",
  ]

+ from dataeval.utils.data.selections._classbalance import ClassBalance
  from dataeval.utils.data.selections._classfilter import ClassFilter
  from dataeval.utils.data.selections._indices import Indices
  from dataeval.utils.data.selections._limit import Limit
dataeval/utils/data/selections/_classbalance.py (new file)

@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+
+ import numpy as np
+
+ from dataeval.typing import Array, ImageClassificationDatum
+ from dataeval.utils._array import as_numpy
+ from dataeval.utils.data._selection import Select, Selection, SelectionStage
+
+
+ class ClassBalance(Selection[ImageClassificationDatum]):
+ """
+ Balance the dataset by class.
+
+ Note
+ ----
+ The total number of instances of each class will be equalized which may result
+ in a lower total number of instances than specified by the selection limit.
+ """
+
+ stage = SelectionStage.FILTER
+
+ def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
+ class_indices: dict[int, list[int]] = {}
+ for i, idx in enumerate(dataset._selection):
+ target = dataset._dataset[idx][1]
+ if isinstance(target, Array):
+ label = int(np.argmax(as_numpy(target)))
+ else:
+ # ObjectDetectionTarget and SegmentationTarget not supported yet
+ raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
+ class_indices.setdefault(label, []).append(i)
+
+ per_class_limit = min(min(len(c) for c in class_indices.values()), dataset._size_limit // len(class_indices))
+ subselection = sorted([i for v in class_indices.values() for i in v[:per_class_limit]])
+ dataset._selection = [dataset._selection[i] for i in subselection]
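The new ClassBalance selection keeps `min(smallest class count, size_limit // number of classes)` items per class. A small sketch of that rule with hypothetical counts (the variable names below are illustrative, not part of the DataEval API):

    # Hypothetical per-class index counts after earlier selections have run
    class_counts = {0: 10, 1: 4, 2: 7}
    size_limit = 12  # e.g. imposed by a Limit selection

    per_class_limit = min(min(class_counts.values()), size_limit // len(class_counts))
    assert per_class_limit == 4  # 4 items kept per class, 12 items total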
dataeval/utils/data/selections/_classfilter.py

@@ -2,7 +2,7 @@ from __future__ import annotations

  __all__ = []

- from typing import Sequence, TypeVar
+ from typing import Sequence

  import numpy as np

@@ -10,50 +10,35 @@ from dataeval.typing import Array, ImageClassificationDatum
  from dataeval.utils._array import as_numpy
  from dataeval.utils.data._selection import Select, Selection, SelectionStage

- TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum)

-
- class ClassFilter(Selection[TImageClassificationDatum]):
+ class ClassFilter(Selection[ImageClassificationDatum]):
  """
- Filter and balance the dataset by class.
+ Filter the dataset by class.

  Parameters
  ----------
- classes : Sequence[int] or None, default None
- The classes to filter by. If None, all classes are included.
- balance : bool, default False
- Whether to balance the classes.
-
- Note
- ----
- If `balance` is True, the total number of instances of each class will
- be equalized. This may result in a lower total number of instances.
+ classes : Sequence[int]
+ The classes to filter by.
  """

  stage = SelectionStage.FILTER

- def __init__(self, classes: Sequence[int] | None = None, balance: bool = False) -> None:
+ def __init__(self, classes: Sequence[int]) -> None:
  self.classes = classes
- self.balance = balance

- def __call__(self, dataset: Select[TImageClassificationDatum]) -> None:
- if self.classes is None and not self.balance:
+ def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
+ if not self.classes:
  return

- per_class_limit = dataset._size_limit // len(self.classes) if self.classes and self.balance else 0
- class_indices: dict[int, list[int]] = {} if self.classes is None else {k: [] for k in self.classes}
- for i, idx in enumerate(dataset._selection):
+ selection = []
+ for idx in dataset._selection:
  target = dataset._dataset[idx][1]
  if isinstance(target, Array):
  label = int(np.argmax(as_numpy(target)))
  else:
  # ObjectDetectionTarget and SegmentationTarget not supported yet
  raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
- if not self.classes or label in self.classes:
- class_indices.setdefault(label, []).append(i)
- if per_class_limit and all(len(indices) >= per_class_limit for indices in class_indices.values()):
- break
-
- per_class_limit = min(len(c) for c in class_indices.values()) if self.balance else dataset._size_limit
- subselection = sorted([i for v in class_indices.values() for i in v[:per_class_limit]])
- dataset._selection = [dataset._selection[i] for i in subselection]
+ if label in self.classes:
+ selection.append(idx)
+
+ dataset._selection = selection
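With balancing split out of ClassFilter, code written against 0.84.0 that used `balance=True` can compose the two selections instead. A hedged migration sketch, assuming both classes are importable from `dataeval.utils.data.selections`:

    from dataeval.utils.data.selections import ClassBalance, ClassFilter

    # 0.84.0: ClassFilter(classes=[3, 5], balance=True)
    # 0.84.1: filter first, then balance with the new selection
    selections = [ClassFilter([3, 5]), ClassBalance()]

Both are FILTER-stage selections, so they are applied at the same stage when combined in a Select.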
dataeval/utils/data/selections/_prioritize.py

@@ -272,7 +272,7 @@ class Prioritize(Selection[Any]):
  return _KMeansComplexitySorter(samples, self._c)

  def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
- emb: NDArray[Any] = embeddings.to_tensor(selection).cpu().numpy()
+ emb: NDArray[Any] = embeddings.to_numpy(selection)
  emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
  return emb

dataeval/utils/data/selections/_shuffle.py

@@ -8,7 +8,7 @@ import numpy as np
  from numpy.random import BitGenerator, Generator, SeedSequence
  from numpy.typing import NDArray

- from dataeval.typing import Array, ArrayLike
+ from dataeval.typing import Array
  from dataeval.utils._array import as_numpy
  from dataeval.utils.data._selection import Select, Selection, SelectionStage

@@ -30,7 +30,7 @@ class Shuffle(Selection[Any]):
  seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
  stage = SelectionStage.ORDER

- def __init__(self, seed: int | ArrayLike | SeedSequence | BitGenerator | Generator | None = None):
+ def __init__(self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None):
  self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed

  def __call__(self, dataset: Select[Any]) -> None:
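Shuffle's `seed` annotation is narrowed from `ArrayLike` to the concrete types it actually handles. A short sketch of seeds that satisfy the 0.84.1 signature (the public import path is assumed to be the selections module):

    import numpy as np
    from dataeval.utils.data.selections import Shuffle

    Shuffle(seed=42)                                # int
    Shuffle(seed=[1, 2, 3])                         # Sequence[int], converted with as_numpy
    Shuffle(seed=np.random.SeedSequence(1234))      # numpy SeedSequence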
dataeval/utils/torch/_internal.py

@@ -2,7 +2,6 @@ from __future__ import annotations

  __all__ = []

- from functools import partial
  from typing import Any, Callable

  import numpy as np
@@ -12,16 +11,16 @@ from torch.utils.data import DataLoader, TensorDataset
  from tqdm import tqdm

  from dataeval.config import DeviceLike, get_device
+ from dataeval.typing import Array


  def predict_batch(
- x: NDArray[Any] | torch.Tensor,
- model: Callable | torch.nn.Module | torch.nn.Sequential,
+ x: Array,
+ model: torch.nn.Module,
  device: DeviceLike | None = None,
  batch_size: int = int(1e10),
  preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
- dtype: type[np.generic] | torch.dtype = np.float32,
- ) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
+ ) -> torch.Tensor:
  """
  Make batch predictions on a model.

@@ -29,7 +28,7 @@
  ----------
  x : np.ndarray | torch.Tensor
  Batch of instances.
- model : Callable | nn.Module | nn.Sequential
+ model : nn.Module
  PyTorch model.
  device : DeviceLike or None, default None
  The hardware device to use if specified, otherwise uses the DataEval
@@ -38,21 +37,18 @@
  Batch size used during prediction.
  preprocess_fn : Callable | None, default None
  Optional preprocessing function for each batch.
- dtype : np.dtype | torch.dtype, default np.float32
- Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.

  Returns
  -------
- NDArray | torch.Tensor | tuple
- Numpy array, torch tensor or tuples of those with model outputs.
+ torch.Tensor
+ PyTorch tensor with model outputs.
  """
  device = get_device(device)
- if isinstance(x, np.ndarray):
- x = torch.tensor(x, device=device)
+ if isinstance(model, torch.nn.Module):
+ model = model.to(device).eval()
+ x = torch.tensor(x, device=device)
  n = len(x)
  n_minibatch = int(np.ceil(n / batch_size))
- return_np = not isinstance(dtype, torch.dtype)
- preds_tuple = None
  preds_array = []
  with torch.no_grad():
  for i in range(n_minibatch):
@@ -60,28 +56,9 @@
  x_batch = x[istart:istop]
  if isinstance(preprocess_fn, Callable):
  x_batch = preprocess_fn(x_batch)
+ preds_array.append(model(x_batch.to(dtype=torch.float32)).cpu())

- preds_tmp = model(x_batch.to(dtype=torch.float32))
- if isinstance(preds_tmp, (list, tuple)):
- if preds_tuple is None: # init tuple with lists to store predictions
- preds_tuple = tuple([] for _ in range(len(preds_tmp)))
- for j, p in enumerate(preds_tmp):
- p = p.cpu() if isinstance(p, torch.Tensor) else p
- preds_tuple[j].append(p if not return_np or isinstance(p, np.ndarray) else p.numpy())
- elif isinstance(preds_tmp, (np.ndarray, torch.Tensor)):
- preds_tmp = preds_tmp.cpu() if isinstance(preds_tmp, torch.Tensor) else preds_tmp
- preds_array.append(
- preds_tmp if not return_np or isinstance(preds_tmp, np.ndarray) else preds_tmp.numpy()
- )
- else:
- raise TypeError(
- f"Model output type {type(preds_tmp)} not supported. The model \
- output type needs to be one of list, tuple, NDArray or \
- torch.Tensor."
- )
- concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
- out = tuple(concat(p) for p in preds_tuple) if preds_tuple is not None else concat(preds_array)
- return out
+ return torch.cat(preds_array, dim=0)


  def trainer(
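The rewritten predict_batch drops the dtype argument and the list/tuple/NumPy output handling: it takes an Array batch and an nn.Module and always returns a single torch.Tensor. A hedged sketch, assuming the function remains importable from the private dataeval.utils.torch._internal module:

    import numpy as np
    import torch
    from dataeval.utils.torch._internal import predict_batch

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    images = np.random.rand(256, 3, 32, 32).astype(np.float32)

    preds = predict_batch(images, model, batch_size=64)  # torch.Tensor of shape (256, 10)
    preds_np = preds.numpy()  # callers that previously relied on dtype=np.float32 now convert explicitly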
{dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.84.0
+ Version: 0.84.1
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -82,8 +82,7 @@ using MAITE-compliant datasets and models.

  **Python versions:** 3.9 - 3.12

- **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*,
- *Gradient*
+ **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*

  Choose your preferred method of installation below or follow our
  [installation guide](https://dataeval.readthedocs.io/en/v0.74.2/installation.html).