dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +78 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +152 -33
- dataeval/metrics/bias/_balance.py +9 -5
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +20 -21
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +8 -8
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +5 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +61 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +23 -19
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +22 -15
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- dataeval-0.83.0.dist-info/RECORD +105 -0
- dataeval/detectors/ood/metadata_ood_mi.py +0 -93
- dataeval-0.82.0.dist-info/RECORD +0 -104
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Literal, Sequence, TypeVar
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
@@ -11,7 +11,9 @@ from PIL import Image
|
|
11
11
|
|
12
12
|
from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
|
13
13
|
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
14
|
-
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from dataeval.typing import Transform
|
15
17
|
|
16
18
|
CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
|
17
19
|
TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
|
@@ -30,21 +32,27 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
30
32
|
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
31
33
|
image_set : "train", "test" or "base", default "train"
|
32
34
|
If "base", returns all of the data to allow the user to create their own splits.
|
33
|
-
transforms : Transform
|
35
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
34
36
|
Transform(s) to apply to the data.
|
35
37
|
verbose : bool, default False
|
36
38
|
If True, outputs print statements.
|
37
39
|
|
38
40
|
Attributes
|
39
41
|
----------
|
40
|
-
|
42
|
+
path : pathlib.Path
|
43
|
+
Location of the folder containing the data.
|
44
|
+
image_set : "train", "test" or "base"
|
45
|
+
The selected image set from the dataset.
|
46
|
+
index2label : dict[int, str]
|
41
47
|
Dictionary which translates from class integers to the associated class strings.
|
42
|
-
label2index : dict
|
48
|
+
label2index : dict[str, int]
|
43
49
|
Dictionary which translates from class strings to the associated class integers.
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
50
|
+
metadata : DatasetMetadata
|
51
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
52
|
+
transforms : Sequence[Transform]
|
53
|
+
The transforms to be applied to the data.
|
54
|
+
size : int
|
55
|
+
The size of the dataset.
|
48
56
|
"""
|
49
57
|
|
50
58
|
_resources = [
|
@@ -1,23 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
4
|
-
|
5
3
|
__all__ = []
|
6
4
|
|
7
5
|
from pathlib import Path
|
8
|
-
from typing import Any, Sequence
|
6
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
9
7
|
|
10
8
|
from numpy.typing import NDArray
|
11
9
|
|
12
10
|
from dataeval.utils.data.datasets._base import BaseODDataset, DataLocation
|
13
|
-
from dataeval.utils.data.datasets.
|
11
|
+
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from dataeval.typing import Transform
|
14
15
|
|
15
16
|
|
16
17
|
class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
17
18
|
"""
|
18
19
|
A side-scan sonar dataset focused on mine (object) detection.
|
19
20
|
|
20
|
-
|
21
21
|
The dataset comes from the paper
|
22
22
|
`Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
|
23
23
|
by N.P. Santos et. al. (2024).
|
@@ -43,21 +43,27 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
43
43
|
download : bool, default False
|
44
44
|
If True, downloads the dataset from the internet and puts it in root directory.
|
45
45
|
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
46
|
-
transforms : Transform
|
46
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
47
47
|
Transform(s) to apply to the data.
|
48
48
|
verbose : bool, default False
|
49
49
|
If True, outputs print statements.
|
50
50
|
|
51
51
|
Attributes
|
52
52
|
----------
|
53
|
-
|
53
|
+
path : pathlib.Path
|
54
|
+
Location of the folder containing the data.
|
55
|
+
image_set : "base"
|
56
|
+
The base image set is the only available image set for the MILCO dataset.
|
57
|
+
index2label : dict[int, str]
|
54
58
|
Dictionary which translates from class integers to the associated class strings.
|
55
|
-
label2index : dict
|
59
|
+
label2index : dict[str, int]
|
56
60
|
Dictionary which translates from class strings to the associated class integers.
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
+
metadata : DatasetMetadata
|
62
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
63
|
+
transforms : Sequence[Transform]
|
64
|
+
The transforms to be applied to the data.
|
65
|
+
size : int
|
66
|
+
The size of the dataset.
|
61
67
|
"""
|
62
68
|
|
63
69
|
_resources = [
|
@@ -3,14 +3,16 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Literal, Sequence, TypeVar
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
11
|
from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
|
12
12
|
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
13
|
-
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from dataeval.typing import Transform
|
14
16
|
|
15
17
|
MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
16
18
|
TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
|
@@ -52,19 +54,33 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
52
54
|
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
53
55
|
image_set : "train", "test" or "base", default "train"
|
54
56
|
If "base", returns all of the data to allow the user to create their own splits.
|
57
|
+
corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
|
58
|
+
"shear", "scale", "rotate", "brightness", "translate", "stripe", "fog", "spatter", \
|
59
|
+
"dotted_line", "zigzag", "canny_edges" or None, default None
|
60
|
+
Corruption to apply to the data.
|
61
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
62
|
+
Transform(s) to apply to the data.
|
55
63
|
verbose : bool, default False
|
56
64
|
If True, outputs print statements.
|
57
65
|
|
58
66
|
Attributes
|
59
67
|
----------
|
60
|
-
|
68
|
+
path : pathlib.Path
|
69
|
+
Location of the folder containing the data.
|
70
|
+
image_set : "train", "test" or "base"
|
71
|
+
The selected image set from the dataset.
|
72
|
+
index2label : dict[int, str]
|
61
73
|
Dictionary which translates from class integers to the associated class strings.
|
62
|
-
label2index : dict
|
74
|
+
label2index : dict[str, int]
|
63
75
|
Dictionary which translates from class strings to the associated class integers.
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
76
|
+
metadata : DatasetMetadata
|
77
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
78
|
+
corruption : str or None
|
79
|
+
Corruption applied to the data.
|
80
|
+
transforms : Sequence[Transform]
|
81
|
+
The transforms to be applied to the data.
|
82
|
+
size : int
|
83
|
+
The size of the dataset.
|
68
84
|
"""
|
69
85
|
|
70
86
|
_resources = [
|
@@ -3,14 +3,16 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Sequence
|
6
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
11
|
from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
|
12
12
|
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
13
|
-
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from dataeval.typing import Transform
|
14
16
|
|
15
17
|
|
16
18
|
class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
@@ -32,19 +34,27 @@ class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
32
34
|
download : bool, default False
|
33
35
|
If True, downloads the dataset from the internet and puts it in root directory.
|
34
36
|
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
37
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
38
|
+
Transform(s) to apply to the data.
|
35
39
|
verbose : bool, default False
|
36
40
|
If True, outputs print statements.
|
37
41
|
|
38
42
|
Attributes
|
39
43
|
----------
|
40
|
-
|
44
|
+
path : pathlib.Path
|
45
|
+
Location of the folder containing the data.
|
46
|
+
image_set : "base"
|
47
|
+
The base image set is the only available image set for the Ships dataset.
|
48
|
+
index2label : dict[int, str]
|
41
49
|
Dictionary which translates from class integers to the associated class strings.
|
42
|
-
label2index : dict
|
50
|
+
label2index : dict[str, int]
|
43
51
|
Dictionary which translates from class strings to the associated class integers.
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
52
|
+
metadata : DatasetMetadata
|
53
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
54
|
+
transforms : Sequence[Transform]
|
55
|
+
The transforms to be applied to the data.
|
56
|
+
size : int
|
57
|
+
The size of the dataset.
|
48
58
|
"""
|
49
59
|
|
50
60
|
_resources = [
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
|
-
from typing import Any, Generic,
|
6
|
+
from typing import Any, Generic, TypedDict, TypeVar
|
7
7
|
|
8
8
|
from torch.utils.data import Dataset
|
9
9
|
from typing_extensions import NotRequired, Required
|
@@ -46,7 +46,3 @@ class SegmentationTarget(Generic[_TArray]):
|
|
46
46
|
|
47
47
|
|
48
48
|
class SegmentationDataset(AnnotatedDataset[tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]]): ...
|
49
|
-
|
50
|
-
|
51
|
-
class Transform(Generic[_TArray], Protocol):
|
52
|
-
def __call__(self, data: _TArray, /) -> _TArray: ...
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Literal, Sequence, TypeVar
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
|
7
7
|
|
8
8
|
import torch
|
9
9
|
from defusedxml.ElementTree import parse
|
@@ -16,7 +16,10 @@ from dataeval.utils.data.datasets._base import (
|
|
16
16
|
DataLocation,
|
17
17
|
)
|
18
18
|
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
|
19
|
-
from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget
|
19
|
+
from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from dataeval.typing import Transform
|
20
23
|
|
21
24
|
_TArray = TypeVar("_TArray")
|
22
25
|
_TTarget = TypeVar("_TTarget")
|
@@ -201,6 +204,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
|
|
201
204
|
boxes: list[list[float]] = []
|
202
205
|
label_str = []
|
203
206
|
root = parse(annotation).getroot()
|
207
|
+
if root is None:
|
208
|
+
raise ValueError(f"Unable to parse {annotation}")
|
204
209
|
num_objects = len(root.findall("object"))
|
205
210
|
additional_meta: dict[str, Any] = {
|
206
211
|
"folder": [root.findtext("folder", default="") for _ in range(num_objects)],
|
@@ -253,21 +258,27 @@ class VOCDetection(
|
|
253
258
|
If "base", then the combined dataset of "train" and "val" is returned.
|
254
259
|
year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
|
255
260
|
The dataset year.
|
256
|
-
transforms : Transform
|
261
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
257
262
|
Transform(s) to apply to the data.
|
258
263
|
verbose : bool, default False
|
259
264
|
If True, outputs print statements.
|
260
265
|
|
261
266
|
Attributes
|
262
267
|
----------
|
263
|
-
|
268
|
+
path : pathlib.Path
|
269
|
+
Location of the folder containing the data.
|
270
|
+
image_set : "train", "val", "test" or "base"
|
271
|
+
The selected image set from the dataset.
|
272
|
+
index2label : dict[int, str]
|
264
273
|
Dictionary which translates from class integers to the associated class strings.
|
265
|
-
label2index : dict
|
274
|
+
label2index : dict[str, int]
|
266
275
|
Dictionary which translates from class strings to the associated class integers.
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
276
|
+
metadata : DatasetMetadata
|
277
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
278
|
+
transforms : Sequence[Transform]
|
279
|
+
The transforms to be applied to the data.
|
280
|
+
size : int
|
281
|
+
The size of the dataset.
|
271
282
|
"""
|
272
283
|
|
273
284
|
|
@@ -277,7 +288,7 @@ class VOCDetectionTorch(
|
|
277
288
|
BaseDatasetTorchMixin,
|
278
289
|
):
|
279
290
|
"""
|
280
|
-
`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
|
291
|
+
`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset as PyTorch tensors.
|
281
292
|
|
282
293
|
Parameters
|
283
294
|
----------
|
@@ -291,21 +302,27 @@ class VOCDetectionTorch(
|
|
291
302
|
If "base", then the combined dataset of "train" and "val" is returned.
|
292
303
|
year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
|
293
304
|
The dataset year.
|
294
|
-
transforms : Transform
|
305
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
295
306
|
Transform(s) to apply to the data.
|
296
307
|
verbose : bool, default False
|
297
308
|
If True, outputs print statements.
|
298
309
|
|
299
310
|
Attributes
|
300
311
|
----------
|
301
|
-
|
312
|
+
path : pathlib.Path
|
313
|
+
Location of the folder containing the data.
|
314
|
+
image_set : "train", "val", "test" or "base"
|
315
|
+
The selected image set from the dataset.
|
316
|
+
index2label : dict[int, str]
|
302
317
|
Dictionary which translates from class integers to the associated class strings.
|
303
|
-
label2index : dict
|
318
|
+
label2index : dict[str, int]
|
304
319
|
Dictionary which translates from class strings to the associated class integers.
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
320
|
+
metadata : DatasetMetadata
|
321
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
322
|
+
transforms : Sequence[Transform]
|
323
|
+
The transforms to be applied to the data.
|
324
|
+
size : int
|
325
|
+
The size of the dataset.
|
309
326
|
"""
|
310
327
|
|
311
328
|
|
@@ -329,21 +346,27 @@ class VOCSegmentation(
|
|
329
346
|
If "base", then the combined dataset of "train" and "val" is returned.
|
330
347
|
year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
|
331
348
|
The dataset year.
|
332
|
-
transforms : Transform
|
349
|
+
transforms : Transform, Sequence[Transform] or None, default None
|
333
350
|
Transform(s) to apply to the data.
|
334
351
|
verbose : bool, default False
|
335
352
|
If True, outputs print statements.
|
336
353
|
|
337
354
|
Attributes
|
338
355
|
----------
|
339
|
-
|
356
|
+
path : pathlib.Path
|
357
|
+
Location of the folder containing the data.
|
358
|
+
image_set : "train", "val", "test" or "base"
|
359
|
+
The selected image set from the dataset.
|
360
|
+
index2label : dict[int, str]
|
340
361
|
Dictionary which translates from class integers to the associated class strings.
|
341
|
-
label2index : dict
|
362
|
+
label2index : dict[str, int]
|
342
363
|
Dictionary which translates from class strings to the associated class integers.
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
364
|
+
metadata : DatasetMetadata
|
365
|
+
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
366
|
+
transforms : Sequence[Transform]
|
367
|
+
The transforms to be applied to the data.
|
368
|
+
size : int
|
369
|
+
The size of the dataset.
|
347
370
|
"""
|
348
371
|
|
349
372
|
def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
|
@@ -4,6 +4,7 @@ __all__ = [
|
|
4
4
|
"ClassFilter",
|
5
5
|
"Indices",
|
6
6
|
"Limit",
|
7
|
+
"Prioritize",
|
7
8
|
"Reverse",
|
8
9
|
"Shuffle",
|
9
10
|
]
|
@@ -11,5 +12,6 @@ __all__ = [
|
|
11
12
|
from dataeval.utils.data.selections._classfilter import ClassFilter
|
12
13
|
from dataeval.utils.data.selections._indices import Indices
|
13
14
|
from dataeval.utils.data.selections._limit import Limit
|
15
|
+
from dataeval.utils.data.selections._prioritize import Prioritize
|
14
16
|
from dataeval.utils.data.selections._reverse import Reverse
|
15
17
|
from dataeval.utils.data.selections._shuffle import Shuffle
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from typing import Sequence
|
5
|
+
from typing import Sequence, TypeVar
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
|
@@ -10,8 +10,10 @@ from dataeval.typing import Array, ImageClassificationDatum
|
|
10
10
|
from dataeval.utils._array import as_numpy
|
11
11
|
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
12
12
|
|
13
|
+
TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum)
|
13
14
|
|
14
|
-
|
15
|
+
|
16
|
+
class ClassFilter(Selection[TImageClassificationDatum]):
|
15
17
|
"""
|
16
18
|
Filter and balance the dataset by class.
|
17
19
|
|
@@ -34,7 +36,7 @@ class ClassFilter(Selection[ImageClassificationDatum]):
|
|
34
36
|
self.classes = classes
|
35
37
|
self.balance = balance
|
36
38
|
|
37
|
-
def __call__(self, dataset: Select[
|
39
|
+
def __call__(self, dataset: Select[TImageClassificationDatum]) -> None:
|
38
40
|
if self.classes is None and not self.balance:
|
39
41
|
return
|
40
42
|
|