dataeval 0.84.1__py3-none-any.whl → 0.86.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/{utils/data → data}/_embeddings.py +137 -17
- dataeval/{utils/data → data}/_metadata.py +20 -8
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/{utils/data → data}/selections/_classbalance.py +1 -2
- dataeval/data/selections/_classfilter.py +110 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +2 -2
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +1 -1
- dataeval/detectors/drift/__init__.py +4 -1
- dataeval/detectors/drift/_base.py +1 -1
- dataeval/detectors/drift/_cvm.py +2 -2
- dataeval/detectors/drift/_ks.py +2 -2
- dataeval/detectors/drift/_mmd.py +2 -2
- dataeval/detectors/drift/_mvdc.py +92 -0
- dataeval/detectors/drift/_nml/__init__.py +6 -0
- dataeval/detectors/drift/_nml/_base.py +68 -0
- dataeval/detectors/drift/_nml/_chunk.py +404 -0
- dataeval/detectors/drift/_nml/_domainclassifier.py +192 -0
- dataeval/detectors/drift/_nml/_result.py +98 -0
- dataeval/detectors/drift/_nml/_thresholds.py +280 -0
- dataeval/detectors/linters/duplicates.py +1 -1
- dataeval/detectors/linters/outliers.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_bias.py +2 -4
- dataeval/outputs/_drift.py +68 -0
- dataeval/outputs/_linters.py +1 -6
- dataeval/outputs/_stats.py +1 -6
- dataeval/typing.py +31 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +3 -3
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_milco.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_mnist.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_ships.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_voc.py +3 -3
- {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/METADATA +3 -2
- dataeval-0.86.0.dist-info/RECORD +114 -0
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -19
- dataeval/utils/data/selections/_classfilter.py +0 -44
- dataeval-0.84.1.dist-info/RECORD +0 -106
- /dataeval/{utils/data → data}/_images.py +0 -0
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/_parity.py
CHANGED
@@ -10,11 +10,11 @@ from numpy.typing import NDArray
|
|
10
10
|
from scipy.stats import chisquare
|
11
11
|
from scipy.stats.contingency import chi2_contingency, crosstab
|
12
12
|
|
13
|
+
from dataeval.data import Metadata
|
13
14
|
from dataeval.outputs import LabelParityOutput, ParityOutput
|
14
15
|
from dataeval.outputs._base import set_metadata
|
15
16
|
from dataeval.typing import ArrayLike
|
16
17
|
from dataeval.utils._array import as_numpy
|
17
|
-
from dataeval.utils.data import Metadata
|
18
18
|
|
19
19
|
|
20
20
|
def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
|
@@ -5,10 +5,10 @@ __all__ = []
|
|
5
5
|
from collections import Counter, defaultdict
|
6
6
|
from typing import Any, Mapping, TypeVar
|
7
7
|
|
8
|
+
from dataeval.data._metadata import Metadata
|
8
9
|
from dataeval.outputs import LabelStatsOutput
|
9
10
|
from dataeval.outputs._base import set_metadata
|
10
11
|
from dataeval.typing import AnnotatedDataset
|
11
|
-
from dataeval.utils.data._metadata import Metadata
|
12
12
|
|
13
13
|
TValue = TypeVar("TValue")
|
14
14
|
|
@@ -38,7 +38,7 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
|
|
38
38
|
--------
|
39
39
|
Calculate basic :term:`statistics<Statistics>` on labels for a dataset.
|
40
40
|
|
41
|
-
>>> from dataeval.
|
41
|
+
>>> from dataeval.data import Metadata
|
42
42
|
>>> stats = labelstats(Metadata(dataset))
|
43
43
|
>>> print(stats.to_table())
|
44
44
|
Class Count: 5
|
dataeval/outputs/__init__.py
CHANGED
@@ -5,7 +5,7 @@ as well as runtime metadata for reproducibility and logging.
|
|
5
5
|
|
6
6
|
from ._base import ExecutionMetadata
|
7
7
|
from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
|
8
|
-
from ._drift import DriftMMDOutput, DriftOutput
|
8
|
+
from ._drift import DriftMMDOutput, DriftMVDCOutput, DriftOutput
|
9
9
|
from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
|
10
10
|
from ._linters import DuplicatesOutput, OutliersOutput
|
11
11
|
from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput, OODPredictorOutput
|
@@ -34,6 +34,7 @@ __all__ = [
|
|
34
34
|
"DivergenceOutput",
|
35
35
|
"DiversityOutput",
|
36
36
|
"DriftMMDOutput",
|
37
|
+
"DriftMVDCOutput",
|
37
38
|
"DriftOutput",
|
38
39
|
"DuplicatesOutput",
|
39
40
|
"ExecutionMetadata",
|
dataeval/outputs/_bias.py
CHANGED
@@ -7,17 +7,17 @@ from dataclasses import asdict, dataclass
|
|
7
7
|
from typing import Any, Literal, TypeVar, overload
|
8
8
|
|
9
9
|
import numpy as np
|
10
|
+
import pandas as pd
|
10
11
|
from numpy.typing import NDArray
|
11
12
|
|
12
13
|
with contextlib.suppress(ImportError):
|
13
|
-
import pandas as pd
|
14
14
|
from matplotlib.figure import Figure
|
15
15
|
|
16
|
+
from dataeval.data._images import Images
|
16
17
|
from dataeval.outputs._base import Output
|
17
18
|
from dataeval.typing import ArrayLike, Dataset
|
18
19
|
from dataeval.utils._array import as_numpy, channels_first_to_last
|
19
20
|
from dataeval.utils._plot import heatmap
|
20
|
-
from dataeval.utils.data._images import Images
|
21
21
|
|
22
22
|
TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
23
23
|
|
@@ -38,8 +38,6 @@ class ToDataFrameMixin:
|
|
38
38
|
-----
|
39
39
|
This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
|
40
40
|
"""
|
41
|
-
import pandas as pd
|
42
|
-
|
43
41
|
return pd.DataFrame(
|
44
42
|
index=self.factor_names, # type: ignore - list[str] is documented as acceptable index type
|
45
43
|
data={
|
dataeval/outputs/_drift.py
CHANGED
@@ -2,11 +2,17 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import contextlib
|
5
6
|
from dataclasses import dataclass
|
6
7
|
|
7
8
|
import numpy as np
|
9
|
+
import pandas as pd
|
8
10
|
from numpy.typing import NDArray
|
9
11
|
|
12
|
+
with contextlib.suppress(ImportError):
|
13
|
+
from matplotlib.figure import Figure
|
14
|
+
|
15
|
+
from dataeval.detectors.drift._nml._result import Metric, PerMetricResult
|
10
16
|
from dataeval.outputs._base import Output
|
11
17
|
|
12
18
|
|
@@ -81,3 +87,65 @@ class DriftOutput(DriftBaseOutput):
|
|
81
87
|
feature_threshold: float
|
82
88
|
p_vals: NDArray[np.float32]
|
83
89
|
distances: NDArray[np.float32]
|
90
|
+
|
91
|
+
|
92
|
+
class DriftMVDCOutput(PerMetricResult):
|
93
|
+
"""Class wrapping the results of the classifier for drift detection and providing plotting functionality."""
|
94
|
+
|
95
|
+
def __init__(self, results_data: pd.DataFrame) -> None:
|
96
|
+
"""Initialize a DomainClassifierCalculator results object.
|
97
|
+
|
98
|
+
Parameters
|
99
|
+
----------
|
100
|
+
results_data : pd.DataFrame
|
101
|
+
Results data returned by a DomainClassifierCalculator.
|
102
|
+
"""
|
103
|
+
metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
|
104
|
+
super().__init__(results_data, [metric])
|
105
|
+
|
106
|
+
def plot(self, showme: bool = True) -> Figure:
|
107
|
+
"""
|
108
|
+
Render the roc_auc metric over the train/test data in relation to the threshold.
|
109
|
+
|
110
|
+
Parameters
|
111
|
+
----------
|
112
|
+
showme : bool, default True
|
113
|
+
Option to display the figure.
|
114
|
+
|
115
|
+
Returns
|
116
|
+
-------
|
117
|
+
matplotlib.figure.Figure
|
118
|
+
|
119
|
+
"""
|
120
|
+
import matplotlib.pyplot as plt
|
121
|
+
|
122
|
+
fig, ax = plt.subplots(dpi=300)
|
123
|
+
resdf = self.to_df()
|
124
|
+
xticks = np.arange(resdf.shape[0])
|
125
|
+
trndf = resdf[resdf["chunk"]["period"] == "reference"]
|
126
|
+
tstdf = resdf[resdf["chunk"]["period"] == "analysis"]
|
127
|
+
# Get local indices for drift markers
|
128
|
+
driftx = np.where(resdf["domain_classifier_auroc"]["alert"].values) # type: ignore | dataframe
|
129
|
+
if np.size(driftx) > 2:
|
130
|
+
ax.plot(resdf.index, resdf["domain_classifier_auroc"]["upper_threshold"], "r--", label="thr_up")
|
131
|
+
ax.plot(resdf.index, resdf["domain_classifier_auroc"]["lower_threshold"], "r--", label="thr_low")
|
132
|
+
ax.plot(trndf.index, trndf["domain_classifier_auroc"]["value"], "b", label="train")
|
133
|
+
ax.plot(tstdf.index, tstdf["domain_classifier_auroc"]["value"], "g", label="test")
|
134
|
+
ax.plot(
|
135
|
+
resdf.index.values[driftx], # type: ignore | dataframe
|
136
|
+
resdf["domain_classifier_auroc"]["value"].values[driftx], # type: ignore | dataframe
|
137
|
+
"dm",
|
138
|
+
markersize=3,
|
139
|
+
label="drift",
|
140
|
+
)
|
141
|
+
ax.set_xticks(xticks)
|
142
|
+
ax.tick_params(axis="x", labelsize=6)
|
143
|
+
ax.tick_params(axis="y", labelsize=6)
|
144
|
+
ax.legend(loc="lower left", fontsize=6)
|
145
|
+
ax.set_title("Domain Classifier, Drift Detection", fontsize=8)
|
146
|
+
ax.set_ylabel("ROC AUC", fontsize=7)
|
147
|
+
ax.set_xlabel("Chunk Index", fontsize=7)
|
148
|
+
ax.set_ylim((0.0, 1.1))
|
149
|
+
if showme:
|
150
|
+
plt.show()
|
151
|
+
return fig
|
dataeval/outputs/_linters.py
CHANGED
@@ -2,15 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
from dataclasses import dataclass
|
7
6
|
from typing import Generic, TypeVar, Union
|
8
7
|
|
8
|
+
import pandas as pd
|
9
9
|
from typing_extensions import TypeAlias
|
10
10
|
|
11
|
-
with contextlib.suppress(ImportError):
|
12
|
-
import pandas as pd
|
13
|
-
|
14
11
|
from dataeval.outputs._base import Output
|
15
12
|
from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
|
16
13
|
|
@@ -168,8 +165,6 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
|
|
168
165
|
-----
|
169
166
|
This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
|
170
167
|
"""
|
171
|
-
import pandas as pd
|
172
|
-
|
173
168
|
if isinstance(self.issues, dict):
|
174
169
|
_, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
|
175
170
|
data = _create_pandas_dataframe(classwise)
|
dataeval/outputs/_stats.py
CHANGED
@@ -2,17 +2,14 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
from dataclasses import dataclass
|
7
6
|
from typing import Any, Iterable, NamedTuple, Optional, Union
|
8
7
|
|
9
8
|
import numpy as np
|
9
|
+
import pandas as pd
|
10
10
|
from numpy.typing import NDArray
|
11
11
|
from typing_extensions import TypeAlias
|
12
12
|
|
13
|
-
with contextlib.suppress(ImportError):
|
14
|
-
import pandas as pd
|
15
|
-
|
16
13
|
from dataeval.outputs._base import Output
|
17
14
|
from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
18
15
|
|
@@ -281,8 +278,6 @@ class LabelStatsOutput(Output):
|
|
281
278
|
-------
|
282
279
|
pd.DataFrame
|
283
280
|
"""
|
284
|
-
import pandas as pd
|
285
|
-
|
286
281
|
total_count = []
|
287
282
|
image_count = []
|
288
283
|
for cls in range(len(self.class_names)):
|
dataeval/typing.py
CHANGED
@@ -98,6 +98,22 @@ class DatasetMetadata(TypedDict, total=False):
|
|
98
98
|
index2label: NotRequired[ReadOnly[dict[int, str]]]
|
99
99
|
|
100
100
|
|
101
|
+
class ModelMetadata(TypedDict, total=False):
|
102
|
+
"""
|
103
|
+
Model metadata required for all `AnnotatedModel` classes.
|
104
|
+
|
105
|
+
Attributes
|
106
|
+
----------
|
107
|
+
id : Required[str]
|
108
|
+
A unique identifier for the model
|
109
|
+
index2label : NotRequired[dict[int, str]]
|
110
|
+
A lookup table converting label value to class name
|
111
|
+
"""
|
112
|
+
|
113
|
+
id: Required[ReadOnly[str]]
|
114
|
+
index2label: NotRequired[ReadOnly[dict[int, str]]]
|
115
|
+
|
116
|
+
|
101
117
|
@runtime_checkable
|
102
118
|
class Dataset(Generic[_T_co], Protocol):
|
103
119
|
"""
|
@@ -238,6 +254,21 @@ SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
|
|
238
254
|
Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
|
239
255
|
"""
|
240
256
|
|
257
|
+
# ========== MODEL ==========
|
258
|
+
|
259
|
+
|
260
|
+
@runtime_checkable
|
261
|
+
class AnnotatedModel(Protocol):
|
262
|
+
"""
|
263
|
+
Protocol for an annotated model.
|
264
|
+
"""
|
265
|
+
|
266
|
+
@property
|
267
|
+
def metadata(self) -> ModelMetadata: ...
|
268
|
+
|
269
|
+
|
270
|
+
# ========== TRANSFORM ==========
|
271
|
+
|
241
272
|
|
242
273
|
@runtime_checkable
|
243
274
|
class Transform(Generic[_T], Protocol):
|
dataeval/utils/__init__.py
CHANGED
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
|
|
4
4
|
DataEval metrics.
|
5
5
|
"""
|
6
6
|
|
7
|
-
__all__ = ["data", "
|
7
|
+
__all__ = ["data", "datasets", "torch"]
|
8
8
|
|
9
|
-
from . import data,
|
9
|
+
from . import data, datasets, torch
|
dataeval/utils/data/__init__.py
CHANGED
@@ -1,26 +1,11 @@
|
|
1
|
-
"""Provides
|
1
|
+
"""Provides access to common Computer Vision datasets."""
|
2
|
+
|
3
|
+
from dataeval.utils.data import collate, metadata
|
4
|
+
from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
|
2
5
|
|
3
6
|
__all__ = [
|
4
7
|
"collate",
|
5
|
-
"
|
6
|
-
"Embeddings",
|
7
|
-
"Images",
|
8
|
-
"Metadata",
|
9
|
-
"Select",
|
10
|
-
"SplitDatasetOutput",
|
11
|
-
"Targets",
|
12
|
-
"split_dataset",
|
8
|
+
"metadata",
|
13
9
|
"to_image_classification_dataset",
|
14
10
|
"to_object_detection_dataset",
|
15
11
|
]
|
16
|
-
|
17
|
-
from dataeval.outputs._utils import SplitDatasetOutput
|
18
|
-
from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
|
19
|
-
from dataeval.utils.data._embeddings import Embeddings
|
20
|
-
from dataeval.utils.data._images import Images
|
21
|
-
from dataeval.utils.data._metadata import Metadata
|
22
|
-
from dataeval.utils.data._selection import Select
|
23
|
-
from dataeval.utils.data._split import split_dataset
|
24
|
-
from dataeval.utils.data._targets import Targets
|
25
|
-
|
26
|
-
from . import collate, datasets
|
dataeval/utils/data/collate.py
CHANGED
@@ -4,6 +4,8 @@ Collate functions used with a PyTorch DataLoader to load data from MAITE complia
|
|
4
4
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
__all__ = ["list_collate_fn", "numpy_collate_fn", "torch_collate_fn"]
|
8
|
+
|
7
9
|
from typing import Any, Iterable, Sequence, TypeVar
|
8
10
|
|
9
11
|
import numpy as np
|
@@ -0,0 +1,17 @@
|
|
1
|
+
"""Provides access to common Computer Vision datasets."""
|
2
|
+
|
3
|
+
from dataeval.utils.datasets._cifar10 import CIFAR10
|
4
|
+
from dataeval.utils.datasets._milco import MILCO
|
5
|
+
from dataeval.utils.datasets._mnist import MNIST
|
6
|
+
from dataeval.utils.datasets._ships import Ships
|
7
|
+
from dataeval.utils.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
|
8
|
+
|
9
|
+
__all__ = [
|
10
|
+
"MNIST",
|
11
|
+
"Ships",
|
12
|
+
"CIFAR10",
|
13
|
+
"MILCO",
|
14
|
+
"VOCDetection",
|
15
|
+
"VOCDetectionTorch",
|
16
|
+
"VOCSegmentation",
|
17
|
+
]
|
@@ -6,9 +6,9 @@ from abc import abstractmethod
|
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
|
8
8
|
|
9
|
-
from dataeval.utils.
|
10
|
-
from dataeval.utils.
|
11
|
-
from dataeval.utils.
|
9
|
+
from dataeval.utils.datasets._fileio import _ensure_exists
|
10
|
+
from dataeval.utils.datasets._mixin import BaseDatasetMixin
|
11
|
+
from dataeval.utils.datasets._types import (
|
12
12
|
AnnotatedDataset,
|
13
13
|
DatasetMetadata,
|
14
14
|
ImageClassificationDataset,
|
@@ -9,8 +9,8 @@ import numpy as np
|
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
from PIL import Image
|
11
11
|
|
12
|
-
from dataeval.utils.
|
13
|
-
from dataeval.utils.
|
12
|
+
from dataeval.utils.datasets._base import BaseICDataset, DataLocation
|
13
|
+
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
14
14
|
|
15
15
|
if TYPE_CHECKING:
|
16
16
|
from dataeval.typing import Transform
|
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence
|
|
7
7
|
|
8
8
|
from numpy.typing import NDArray
|
9
9
|
|
10
|
-
from dataeval.utils.
|
11
|
-
from dataeval.utils.
|
10
|
+
from dataeval.utils.datasets._base import BaseODDataset, DataLocation
|
11
|
+
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from dataeval.typing import Transform
|
@@ -8,8 +8,8 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.utils.
|
12
|
-
from dataeval.utils.
|
11
|
+
from dataeval.utils.datasets._base import BaseICDataset, DataLocation
|
12
|
+
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
13
13
|
|
14
14
|
if TYPE_CHECKING:
|
15
15
|
from dataeval.typing import Transform
|
@@ -8,8 +8,8 @@ from typing import TYPE_CHECKING, Any, Sequence
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.utils.
|
12
|
-
from dataeval.utils.
|
11
|
+
from dataeval.utils.datasets._base import BaseICDataset, DataLocation
|
12
|
+
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
13
13
|
|
14
14
|
if TYPE_CHECKING:
|
15
15
|
from dataeval.typing import Transform
|
@@ -9,7 +9,7 @@ import torch
|
|
9
9
|
from defusedxml.ElementTree import parse
|
10
10
|
from numpy.typing import NDArray
|
11
11
|
|
12
|
-
from dataeval.utils.
|
12
|
+
from dataeval.utils.datasets._base import (
|
13
13
|
BaseDataset,
|
14
14
|
BaseODDataset,
|
15
15
|
BaseSegDataset,
|
@@ -17,8 +17,8 @@ from dataeval.utils.data.datasets._base import (
|
|
17
17
|
_TArray,
|
18
18
|
_TTarget,
|
19
19
|
)
|
20
|
-
from dataeval.utils.
|
21
|
-
from dataeval.utils.
|
20
|
+
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
|
21
|
+
from dataeval.utils.datasets._types import ObjectDetectionTarget, SegmentationTarget
|
22
22
|
|
23
23
|
if TYPE_CHECKING:
|
24
24
|
from dataeval.typing import Transform
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.86.0
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -23,10 +23,11 @@ Classifier: Topic :: Scientific/Engineering
|
|
23
23
|
Provides-Extra: all
|
24
24
|
Requires-Dist: defusedxml (>=0.7.1)
|
25
25
|
Requires-Dist: fast_hdbscan (==0.2.0)
|
26
|
+
Requires-Dist: lightgbm (>=4)
|
26
27
|
Requires-Dist: matplotlib (>=3.7.1) ; extra == "all"
|
27
28
|
Requires-Dist: numba (>=0.59.1)
|
28
29
|
Requires-Dist: numpy (>=1.24.2)
|
29
|
-
Requires-Dist: pandas (>=2.0)
|
30
|
+
Requires-Dist: pandas (>=2.0)
|
30
31
|
Requires-Dist: pillow (>=10.3.0)
|
31
32
|
Requires-Dist: requests
|
32
33
|
Requires-Dist: scikit-learn (>=1.5.0)
|
@@ -0,0 +1,114 @@
|
|
1
|
+
dataeval/__init__.py,sha256=GdieNQ3woQUTyIFklJx7AgEeiBCz9gXzo-UVt6YFHPo,1636
|
2
|
+
dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
|
3
|
+
dataeval/config.py,sha256=lD1YDH8HosFeRU5rQEYRBcmXMZy-csWaMlJTRZGd9iU,3582
|
4
|
+
dataeval/data/__init__.py,sha256=qNnRRiVP_sLthkkHpUrMgI_r8dQK-cC-xoGrrjQeRKc,544
|
5
|
+
dataeval/data/_embeddings.py,sha256=6Medqj_JCQt1iwZwWGSs1OeX-bHB8bg5BJqADY1N2s8,12883
|
6
|
+
dataeval/data/_images.py,sha256=WF9XJRka8ohUdyI2IKBMAy3JoJhOm1iC-8tbYl8woRM,2642
|
7
|
+
dataeval/data/_metadata.py,sha256=mK-WbrFkMo3v8f66uHT4B6-Fsc1odh0CcMTuz2aXSZc,14968
|
8
|
+
dataeval/data/_selection.py,sha256=rYCM4KTqLSOYOzyjKCQKH2KQgJhxNnB2g3pY4JbOEYc,4503
|
9
|
+
dataeval/data/_split.py,sha256=6Jtm_i__CcPtNE3eSeBdPxc7gn7Cp-GM7g9wJWFlVus,16761
|
10
|
+
dataeval/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
|
11
|
+
dataeval/data/selections/__init__.py,sha256=2m8ZB53wXzqLcqmc6p5atO6graB6ZyiRSNJFxf11X_g,613
|
12
|
+
dataeval/data/selections/_classbalance.py,sha256=7v8ApoL3X8eCZ6fGDNTehE_bZ1loaP3TlhsJLaICVWg,1458
|
13
|
+
dataeval/data/selections/_classfilter.py,sha256=VSNl_BSPRHQOBU6GYQwPZhl7j2jYESVJSSdyqWiG_vA,4394
|
14
|
+
dataeval/data/selections/_indices.py,sha256=RFsR9z10aM3N0gJSfKrukFpi-LkiQGXoOwXhmOQ5cpg,630
|
15
|
+
dataeval/data/selections/_limit.py,sha256=JG4GmEiNKt3sk4PbOUbBnGGzNlyz72H-kQrt8COMm4Y,512
|
16
|
+
dataeval/data/selections/_prioritize.py,sha256=yw51ZQk6FPvyC38M4_pS_Se2Dq0LDFcdDhfbsELzTZc,11306
|
17
|
+
dataeval/data/selections/_reverse.py,sha256=b67kNC43A5KpQOic5gifjo9HpJ7FMh4LFCrfovPiJ-M,368
|
18
|
+
dataeval/data/selections/_shuffle.py,sha256=gVz_2T4rlucq8Ytqz5jvmmZdTrZDaIv43jJbq97tLjQ,1173
|
19
|
+
dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
|
20
|
+
dataeval/detectors/drift/__init__.py,sha256=Jqv98oOVeC2tvHlNGxQ8RJ6De2q4SyS5lTpaYlb4ocM,756
|
21
|
+
dataeval/detectors/drift/_base.py,sha256=amGqzUAe8fU5qwM5lq1p8PCuhjGh9MHkdW1zeBF1LEE,7574
|
22
|
+
dataeval/detectors/drift/_cvm.py,sha256=cS33zWJmFY1fft1XcANcP2jSD5ou7TxvIU2AldhTynM,3004
|
23
|
+
dataeval/detectors/drift/_ks.py,sha256=uMc5-NA-lSV1IODrY8uJe87ll3uRJT_oXLJFXy95M1w,3186
|
24
|
+
dataeval/detectors/drift/_mmd.py,sha256=wHUy_vUafCikrZ_WX8qQXpxFwzw07-5zVutloR6hl1k,11589
|
25
|
+
dataeval/detectors/drift/_mvdc.py,sha256=ABxGut6KzxF_oM-Hs87WARCR0692dhPVdZNoGGwJaa4,3058
|
26
|
+
dataeval/detectors/drift/_nml/__init__.py,sha256=MNyKyZlfTjr5uQql2uBBfRkUdsuduie_WJdn09GYmqg,137
|
27
|
+
dataeval/detectors/drift/_nml/_base.py,sha256=g8RmOnsBVN8vV1S9B9JaQQLudcbyKERwy4OuDjGIxb8,2632
|
28
|
+
dataeval/detectors/drift/_nml/_chunk.py,sha256=QxohvSycm_cjldmK-ll-APfIsopPgeATHV-9aejyIKE,13826
|
29
|
+
dataeval/detectors/drift/_nml/_domainclassifier.py,sha256=ccb1tgJ_K7gMYtg1Wdy2gPIpYIhconHQVu3xW5v0hjs,7743
|
30
|
+
dataeval/detectors/drift/_nml/_result.py,sha256=mnWnP1CwzrDChJygcsuFhkKR5g3yAQS520oo-l9PcZU,3273
|
31
|
+
dataeval/detectors/drift/_nml/_thresholds.py,sha256=jnhfd0qR99TKF0PyUVcbtE7cj9lic0QxwrWq_fwoAHM,12687
|
32
|
+
dataeval/detectors/drift/_uncertainty.py,sha256=BHlykJ-r7TGLJxdPfoazXnoAJ1qVDzbk5HjAMdsnHz8,5847
|
33
|
+
dataeval/detectors/drift/updates.py,sha256=L1PnrPlIE1x6ujCc5mCwjcAZwadVTn-Zjb6MnTDvzJQ,2251
|
34
|
+
dataeval/detectors/linters/__init__.py,sha256=xn2zPwUcmsuf-Jd9uw6AVI11C9z1b1Y9fYtuFnXenZ0,404
|
35
|
+
dataeval/detectors/linters/duplicates.py,sha256=X5WSEvI_BHkLoXjkaHK6wTnSkx4IjpO_exMRjSlhc70,4963
|
36
|
+
dataeval/detectors/linters/outliers.py,sha256=D8A-Fov5iUrlU9xMX5Ht33FqUY8Lk5ulC6BlHbUoLwU,9048
|
37
|
+
dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
|
38
|
+
dataeval/detectors/ood/ae.py,sha256=fTrUfFxv6xUqzKpwMC8rW3JrizA16M_bgzqLuBKMrS0,2944
|
39
|
+
dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
|
40
|
+
dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
|
41
|
+
dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
|
42
|
+
dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
|
43
|
+
dataeval/metadata/_distance.py,sha256=T1Umju_QwBiLmn1iUbxZagzBS2VnHaDIdp6j-NpaZuk,4076
|
44
|
+
dataeval/metadata/_ood.py,sha256=lnKtKModArnUrAhH_XswEtUAhUkh1U_oNsLt1UmNP44,12748
|
45
|
+
dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
|
46
|
+
dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
|
47
|
+
dataeval/metrics/bias/__init__.py,sha256=329S1_3WnWqeU4-qVcbe0fMy4lDrj9uKslWHIQf93yg,839
|
48
|
+
dataeval/metrics/bias/_balance.py,sha256=l1hTVkVwD85bP20MTthA-I5BkvbytylQkJu3Q6iTuPA,6152
|
49
|
+
dataeval/metrics/bias/_completeness.py,sha256=BysXU2Jpw33n5dl3acJFEqF3mFGiJLsfG4n5Q2fkTaY,4608
|
50
|
+
dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
|
51
|
+
dataeval/metrics/bias/_diversity.py,sha256=B_qWVDMZfh818U0qVm8yidquB0H0XvW8N75OWVWXy2g,5814
|
52
|
+
dataeval/metrics/bias/_parity.py,sha256=ea1D-eJh6cJxQ11XD6VbDXBKecE0jJJwptGD7LQJmBw,11529
|
53
|
+
dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
|
54
|
+
dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
|
55
|
+
dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
|
56
|
+
dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
|
57
|
+
dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
|
58
|
+
dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
|
59
|
+
dataeval/metrics/stats/_base.py,sha256=YIfOVGd7E19B4dpAnzDYRQkaikvRRyJIpznJNfVtPdw,10750
|
60
|
+
dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
|
61
|
+
dataeval/metrics/stats/_dimensionstats.py,sha256=73mFP-Myxne0peFliwvTntc0kk4cpq0krzMvSLDSIMM,2702
|
62
|
+
dataeval/metrics/stats/_hashstats.py,sha256=gp9X_pnTT3mPH9YNrWLdn2LQPK_epJ3dQRoyOCwmKlg,4758
|
63
|
+
dataeval/metrics/stats/_imagestats.py,sha256=gUPNgN5Zwzdr7WnSwbve1NXNsyxd5dy3cSnlR_7guCg,3007
|
64
|
+
dataeval/metrics/stats/_labelstats.py,sha256=lz8I6eSd8tFkmQqy5cOG8hn9yxs0mP-Ic9ratFHiuoU,2813
|
65
|
+
dataeval/metrics/stats/_pixelstats.py,sha256=SfergRbjNJE4h0xqe-0c8RnKtZmEkZ9MwExdipLSGvg,3247
|
66
|
+
dataeval/metrics/stats/_visualstats.py,sha256=cq4AbF2B50Ihbzb86FphcnKQ1TSwNnP3PsnbpiPQZWw,3698
|
67
|
+
dataeval/outputs/__init__.py,sha256=geHB5M3QOiFFaQGV4ZwDTTKpqZPvPePbqG7lzaPhaXQ,1741
|
68
|
+
dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
|
69
|
+
dataeval/outputs/_bias.py,sha256=_4qgboPstvEFBjTPZOVAOOaXb_BMARLiHY_ElA5wD8E,12368
|
70
|
+
dataeval/outputs/_drift.py,sha256=kS6gGfaf0XOivf1D8go2fzF5yxl0EHlWFlkwv-4LMNI,4770
|
71
|
+
dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
|
72
|
+
dataeval/outputs/_linters.py,sha256=PqLa2wIAkwC-NCb5dhDN29PtTiCUk2TLDFpsMO7Awrc,6325
|
73
|
+
dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI,1710
|
74
|
+
dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
|
75
|
+
dataeval/outputs/_stats.py,sha256=ACUzwsalDl-bV8llaBArZQ1tLj07RFvzmv-IXViAvSA,13089
|
76
|
+
dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
|
77
|
+
dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
|
78
|
+
dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
79
|
+
dataeval/typing.py,sha256=GDMuef-oFFukNtsiKFmsExHdNvYR_j-tQcsCwZ9reow,7198
|
80
|
+
dataeval/utils/__init__.py,sha256=hRvyUK7b3d6JBEV5u47rFcOHEcmDYqAvZQw_T5pDAWw,264
|
81
|
+
dataeval/utils/_array.py,sha256=KqAdXEMjcXYvdWdYEEoEbigwQJ4S9VYxQS3sRFeY5XY,5929
|
82
|
+
dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
|
83
|
+
dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,5436
|
84
|
+
dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
|
85
|
+
dataeval/utils/_image.py,sha256=capzF_X5H0jy0PmTP3Hf52GFgLqrnfU6gS4tiwck9jo,1939
|
86
|
+
dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
|
87
|
+
dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
|
88
|
+
dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
|
89
|
+
dataeval/utils/data/__init__.py,sha256=xGzrjrOxOP2DP1tU84AWMKPnSxFvSjM81CTlDg4rNM8,331
|
90
|
+
dataeval/utils/data/_dataset.py,sha256=MHY582yRm4FxQkkLWUhKZBb7ZyvWypM6ldUG89vd3uE,7936
|
91
|
+
dataeval/utils/data/collate.py,sha256=5egEEKhNNCGeNLChO1p6dZ4Wg6x51VEaMNHz7hEZUxI,3936
|
92
|
+
dataeval/utils/data/metadata.py,sha256=1XeGYj_e97-nJ_IrWEHPhWICmouYU5qbXWbp7uhZrIE,14171
|
93
|
+
dataeval/utils/datasets/__init__.py,sha256=Jfe7XI_9U5S4wuI_2QCoeuWNOxz4j0nAQvxc5wG5mWY,486
|
94
|
+
dataeval/utils/datasets/_base.py,sha256=TpmgPzF3EShCLAF5S4Zf9lFN78q17bTZF6AUE1qKdlk,8857
|
95
|
+
dataeval/utils/datasets/_cifar10.py,sha256=oSX5JEzbBM4zGC9kC7-hVTOglms3rYaUuYiA00_DUJ4,5439
|
96
|
+
dataeval/utils/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
|
97
|
+
dataeval/utils/datasets/_milco.py,sha256=BF2XvyzuOop1mg5pFZcRfYmZcezlbpZWHyd_TtEHFF4,7573
|
98
|
+
dataeval/utils/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
|
99
|
+
dataeval/utils/datasets/_mnist.py,sha256=4WOkQTORYMs6KEeyyJgChTnH03797y4ezgaZtYqplh4,8102
|
100
|
+
dataeval/utils/datasets/_ships.py,sha256=RMdX2KlnXJYOTzBb6euA5TAqxs-S8b56pAGiyQhNMuo,4870
|
101
|
+
dataeval/utils/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
|
102
|
+
dataeval/utils/datasets/_voc.py,sha256=kif6ms_romK6VElP4pf2SK4cJ5dEHDOkxSaSaeP3c5k,15565
|
103
|
+
dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
|
104
|
+
dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
|
105
|
+
dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
|
106
|
+
dataeval/utils/torch/_internal.py,sha256=vHy-DzPhmvE8h3wmWc3aciBJ8nDGzQ1z1jTZgGjmDyM,4154
|
107
|
+
dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
|
108
|
+
dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
|
109
|
+
dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
|
110
|
+
dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
|
111
|
+
dataeval-0.86.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
|
112
|
+
dataeval-0.86.0.dist-info/METADATA,sha256=viF0VCgv5_1SzwfTVCTNdbw1q5k1D3hgJhB7PoZ1tCM,5321
|
113
|
+
dataeval-0.86.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
114
|
+
dataeval-0.86.0.dist-info/RECORD,,
|
@@ -1,17 +0,0 @@
|
|
1
|
-
"""Provides access to common Computer Vision datasets."""
|
2
|
-
|
3
|
-
from dataeval.utils.data.datasets._cifar10 import CIFAR10
|
4
|
-
from dataeval.utils.data.datasets._milco import MILCO
|
5
|
-
from dataeval.utils.data.datasets._mnist import MNIST
|
6
|
-
from dataeval.utils.data.datasets._ships import Ships
|
7
|
-
from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
|
8
|
-
|
9
|
-
__all__ = [
|
10
|
-
"MNIST",
|
11
|
-
"Ships",
|
12
|
-
"CIFAR10",
|
13
|
-
"MILCO",
|
14
|
-
"VOCDetection",
|
15
|
-
"VOCDetectionTorch",
|
16
|
-
"VOCSegmentation",
|
17
|
-
]
|
@@ -1,19 +0,0 @@
|
|
1
|
-
"""Provides selection classes for selecting subsets of Computer Vision datasets."""
|
2
|
-
|
3
|
-
__all__ = [
|
4
|
-
"ClassBalance",
|
5
|
-
"ClassFilter",
|
6
|
-
"Indices",
|
7
|
-
"Limit",
|
8
|
-
"Prioritize",
|
9
|
-
"Reverse",
|
10
|
-
"Shuffle",
|
11
|
-
]
|
12
|
-
|
13
|
-
from dataeval.utils.data.selections._classbalance import ClassBalance
|
14
|
-
from dataeval.utils.data.selections._classfilter import ClassFilter
|
15
|
-
from dataeval.utils.data.selections._indices import Indices
|
16
|
-
from dataeval.utils.data.selections._limit import Limit
|
17
|
-
from dataeval.utils.data.selections._prioritize import Prioritize
|
18
|
-
from dataeval.utils.data.selections._reverse import Reverse
|
19
|
-
from dataeval.utils.data.selections._shuffle import Shuffle
|
@@ -1,44 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
from typing import Sequence
|
6
|
-
|
7
|
-
import numpy as np
|
8
|
-
|
9
|
-
from dataeval.typing import Array, ImageClassificationDatum
|
10
|
-
from dataeval.utils._array import as_numpy
|
11
|
-
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
12
|
-
|
13
|
-
|
14
|
-
class ClassFilter(Selection[ImageClassificationDatum]):
|
15
|
-
"""
|
16
|
-
Filter the dataset by class.
|
17
|
-
|
18
|
-
Parameters
|
19
|
-
----------
|
20
|
-
classes : Sequence[int]
|
21
|
-
The classes to filter by.
|
22
|
-
"""
|
23
|
-
|
24
|
-
stage = SelectionStage.FILTER
|
25
|
-
|
26
|
-
def __init__(self, classes: Sequence[int]) -> None:
|
27
|
-
self.classes = classes
|
28
|
-
|
29
|
-
def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
|
30
|
-
if not self.classes:
|
31
|
-
return
|
32
|
-
|
33
|
-
selection = []
|
34
|
-
for idx in dataset._selection:
|
35
|
-
target = dataset._dataset[idx][1]
|
36
|
-
if isinstance(target, Array):
|
37
|
-
label = int(np.argmax(as_numpy(target)))
|
38
|
-
else:
|
39
|
-
# ObjectDetectionTarget and SegmentationTarget not supported yet
|
40
|
-
raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
|
41
|
-
if label in self.classes:
|
42
|
-
selection.append(idx)
|
43
|
-
|
44
|
-
dataset._selection = selection
|