dataeval 0.86.7__py3-none-any.whl → 0.86.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +10 -3
- dataeval/_version.py +21 -0
- dataeval/config.py +7 -1
- dataeval/detectors/drift/_mvdc.py +2 -9
- dataeval/detectors/drift/_nml/_chunk.py +2 -2
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/detectors/ood/base.py +3 -3
- dataeval/metrics/bias/_completeness.py +3 -3
- dataeval/metrics/bias/_coverage.py +2 -2
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/estimators/_ber.py +2 -2
- dataeval/metrics/estimators/_divergence.py +2 -2
- dataeval/outputs/_estimators.py +6 -6
- dataeval/utils/_array.py +20 -9
- dataeval/utils/_clusterer.py +7 -7
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +1 -1
- dataeval/utils/datasets/_base.py +12 -8
- dataeval/utils/datasets/_fileio.py +3 -3
- dataeval/utils/datasets/_milco.py +1 -1
- dataeval/utils/datasets/_seadrone.py +512 -0
- dataeval/utils/datasets/_voc.py +3 -3
- dataeval/utils/torch/_internal.py +3 -3
- dataeval/utils/torch/trainer.py +1 -1
- dataeval/workflows/sufficiency.py +53 -10
- {dataeval-0.86.7.dist-info → dataeval-0.86.9.dist-info}/METADATA +67 -47
- {dataeval-0.86.7.dist-info → dataeval-0.86.9.dist-info}/RECORD +33 -31
- {dataeval-0.86.7.dist-info → dataeval-0.86.9.dist-info}/WHEEL +1 -1
- {dataeval-0.86.7.dist-info → dataeval-0.86.9.dist-info/licenses}/LICENSE.txt +0 -0
dataeval/__init__.py
CHANGED
@@ -7,12 +7,19 @@ shifts that impact performance of deployed models.
 
 from __future__ import annotations
 
-
-
+try:
+    from ._version import __version__
+except ImportError:
+    __version__ = "unknown"
+
+# Strongly type for pyright
+__version__ = str(__version__)
+
+__all__ = ["__version__", "config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
 
 import logging
 
-from
+from . import config, detectors, metrics, typing, utils, workflows
 
 logging.getLogger(__name__).addHandler(logging.NullHandler())
 
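The package version is now resolved at import time from the setuptools-scm-generated `_version.py` (next section), with a string fallback for source trees that lack build metadata. The same pattern, as a standalone sketch rather than a copy of the package code:

```python
# Minimal sketch of the fallback added in 0.86.9 (mirrors dataeval/__init__.py)
try:
    from dataeval._version import __version__
except ImportError:
    __version__ = "unknown"

print(__version__)  # "0.86.9" for this wheel, "unknown" without build metadata
```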
dataeval/_version.py
ADDED
@@ -0,0 +1,21 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.86.9'
+__version_tuple__ = version_tuple = (0, 86, 9)
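Alongside the string form, the generated module exposes a structured tuple, which is handy for version gates. A hypothetical consumer (dataeval itself only re-exports `__version__`):

```python
from dataeval._version import __version_tuple__

# For tagged releases the tuple is all ints, e.g. (0, 86, 9); dev builds may
# mix in strings, which makes ordered comparisons brittle.
if __version_tuple__ >= (0, 86, 9):
    print("running 0.86.9 or newer")
```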
dataeval/config.py
CHANGED
@@ -77,7 +77,13 @@ def get_device(override: DeviceLike | None = None) -> torch.device:
     """
     if override is None:
         global _device
-        return
+        return (
+            torch.get_default_device()
+            if hasattr(torch, "get_default_device")
+            else torch.device("cpu")
+            if _device is None
+            else _device
+        )
     return _todevice(override)
 
 
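Python's conditional expressions chain right-associatively, so the new multi-line return can be read as the sketch below; this only restates the expression shown in the hunk, it is not dataeval code:

```python
import torch

_device = None  # stand-in for dataeval.config's module-level override

def resolved_default() -> torch.device:
    # Parenthesized reading of the new return expression: get_default_device()
    # wins whenever the installed torch exposes it (recent releases add this
    # accessor); otherwise the override, defaulting to CPU, is returned.
    if hasattr(torch, "get_default_device"):
        return torch.get_default_device()
    return torch.device("cpu") if _device is None else _device
```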
dataeval/detectors/drift/_mvdc.py
CHANGED
@@ -1,16 +1,9 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 import numpy as np
 import pandas as pd
 from numpy.typing import ArrayLike
 
-if TYPE_CHECKING:
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker
 from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
 from dataeval.detectors.drift._nml._thresholds import ConstantThreshold
@@ -52,7 +45,7 @@ class DriftMVDC:
             threshold=ConstantThreshold(lower=self.threshold[0], upper=self.threshold[1]),
         )
 
-    def fit(self, x_ref: ArrayLike) -> Self:
+    def fit(self, x_ref: ArrayLike) -> DriftMVDC:
         """
         Fit the domain classifier on the training dataframe
 
@@ -63,7 +56,7 @@ class DriftMVDC:
 
         Returns
         -------
-
+        DriftMVDC
 
         """
         # for 1D input, assume that is 1 sample: dim[1,n_features]
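With `Self` gone, `fit` now advertises the concrete class, dropping the `typing_extensions` backport on older Pythons while keeping the fluent style. A hypothetical usage sketch; the import path, default constructor arguments, and `predict` call are assumptions, not taken from this diff:

```python
import numpy as np
from dataeval.detectors.drift import DriftMVDC  # import path assumed

rng = np.random.default_rng(0)
x_ref = rng.normal(size=(500, 8))    # hypothetical reference features
x_test = rng.normal(size=(200, 8))   # hypothetical analysis features

# fit() returns the detector instance, so construction and fitting chain
detector = DriftMVDC().fit(x_ref)
result = detector.predict(x_test)    # method name assumed
```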
dataeval/detectors/drift/_nml/_chunk.py
CHANGED
@@ -46,10 +46,10 @@ class Chunk(ABC):
         return self.data.shape[0]
 
     @abstractmethod
-    def __add__(self, other:
+    def __add__(self, other: Any) -> Any: ...
 
     @abstractmethod
-    def __lt__(self, other:
+    def __lt__(self, other: Any) -> bool: ...
 
     @abstractmethod
     def dict(self) -> dict[str, Any]: ...
dataeval/detectors/ood/ae.py
CHANGED
@@ -65,7 +65,7 @@ class OOD_AE(OODBase):
         self,
         x_ref: ArrayLike,
         threshold_perc: float,
-        loss_fn: Callable[..., torch.nn.Module] | None = None,
+        loss_fn: Callable[..., torch.Tensor] | None = None,
         optimizer: torch.optim.Optimizer | None = None,
         epochs: int = 20,
         batch_size: int = 64,
dataeval/detectors/ood/base.py
CHANGED
@@ -22,7 +22,7 @@ from dataeval.utils.torch._gmm import GaussianMixtureModelParams, gmm_params
 from dataeval.utils.torch._internal import trainer
 
 
-class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
+class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.Tensor], torch.optim.Optimizer]):
     def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         self.device: torch.device = get_device(device)
         super().__init__(model)
@@ -31,7 +31,7 @@ class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
         self,
         x_ref: ArrayLike,
         threshold_perc: float,
-        loss_fn: Callable[..., torch.nn.Module] | None,
+        loss_fn: Callable[..., torch.Tensor] | None,
         optimizer: torch.optim.Optimizer | None,
         epochs: int,
         batch_size: int,
@@ -82,7 +82,7 @@ class OODBaseGMM(OODBase, OODGMMMixin[GaussianMixtureModelParams]):
         self,
         x_ref: ArrayLike,
         threshold_perc: float,
-        loss_fn: Callable[..., torch.nn.Module] | None,
+        loss_fn: Callable[..., torch.Tensor] | None,
         optimizer: torch.optim.Optimizer | None,
         epochs: int,
         batch_size: int,
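The corrected annotation matches what a loss callable actually produces: invoking a loss returns a `torch.Tensor`, not an `nn.Module`. Standard loss instances still satisfy the new type, since calling them yields a tensor:

```python
import torch

loss_fn = torch.nn.MSELoss()          # an nn.Module whose __call__ returns a Tensor
pred = torch.zeros(4, 3)
target = torch.ones(4, 3)
loss: torch.Tensor = loss_fn(pred, target)  # satisfies Callable[..., torch.Tensor]
print(loss.item())                    # 1.0
```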
dataeval/metrics/bias/_completeness.py
CHANGED
@@ -9,11 +9,11 @@ import numpy as np
 
 from dataeval.config import EPSILON
 from dataeval.outputs import CompletenessOutput
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array
 from dataeval.utils._array import ensure_embeddings
 
 
-def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+def completeness(embeddings: Array, quantiles: int) -> CompletenessOutput:
     """
     Calculate the fraction of boxes in a grid defined by quantiles that
     contain at least one data point.
@@ -21,7 +21,7 @@ def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
 
     Parameters
     ----------
-    embeddings : ArrayLike
+    embeddings : Array
         Embedded dataset (or other low-dimensional data) (nxp)
     quantiles : int
         number of quantile values to use for partitioning each dimension
dataeval/metrics/bias/_coverage.py
CHANGED
@@ -10,13 +10,13 @@ from scipy.spatial.distance import pdist, squareform
 
 from dataeval.outputs import CoverageOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array
 from dataeval.utils._array import ensure_embeddings, flatten
 
 
 @set_metadata
 def coverage(
-    embeddings: ArrayLike,
+    embeddings: Array,
     radius_type: Literal["adaptive", "naive"] = "adaptive",
     num_observations: int = 20,
     percent: float = 0.01,
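Judging by the parallel edits, the bias metrics now annotate their inputs with dataeval's `Array` protocol instead of `ArrayLike`; NumPy arrays and torch tensors satisfy it. A hypothetical call using the defaults visible in the new `coverage` signature (the public import path is assumed; the data is kept in the unit interval given the range check in `ensure_embeddings` further down):

```python
import numpy as np
from dataeval.metrics.bias import coverage  # import path assumed

embeddings = np.random.default_rng(0).uniform(size=(128, 16)).astype(np.float32)
result = coverage(embeddings, radius_type="adaptive", num_observations=20)
```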
dataeval/metrics/bias/_parity.py
CHANGED
@@ -271,7 +271,7 @@ def parity(metadata: Metadata) -> ParityOutput:
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
         contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]
 
-        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]
+        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]  # type: ignore
 
     if insufficient_data:
         warnings.warn(
dataeval/metrics/estimators/_ber.py
CHANGED
@@ -22,7 +22,7 @@ from scipy.stats import mode
 from dataeval.config import EPSILON
 from dataeval.outputs import BEROutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array
 from dataeval.utils._array import as_numpy, ensure_embeddings
 from dataeval.utils._method import get_method
 from dataeval.utils._mst import compute_neighbors, minimum_spanning_tree
@@ -105,7 +105,7 @@ _BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
 
 
 @set_metadata
-def ber(embeddings:
+def ber(embeddings: Array, labels: Array, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
     """
     An estimator for Multi-class :term:`Bayes error rate<Bayes Error Rate (BER)>` \
     using FR or KNN test statistic basis.
dataeval/metrics/estimators/_divergence.py
CHANGED
@@ -14,7 +14,7 @@ from numpy.typing import NDArray
 
 from dataeval.outputs import DivergenceOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array
 from dataeval.utils._array import ensure_embeddings
 from dataeval.utils._method import get_method
 from dataeval.utils._mst import compute_neighbors, minimum_spanning_tree
@@ -65,7 +65,7 @@ _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
 
 
 @set_metadata
-def divergence(emb_a:
+def divergence(emb_a: Array, emb_b: Array, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
     """
     Calculates the :term:`divergence` and any errors between the datasets.
 
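The estimator signatures, including their defaults, are now fully visible in the diff. A hypothetical invocation of `ber` (import path assumed, data invented):

```python
import numpy as np
from dataeval.metrics.estimators import ber  # import path assumed

rng = np.random.default_rng(0)
embeddings = rng.uniform(size=(200, 32))          # hypothetical embeddings in [0, 1]
labels = rng.integers(0, 4, size=200)             # hypothetical class labels
out = ber(embeddings, labels, k=1, method="KNN")  # defaults from the new signature
```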
dataeval/outputs/_estimators.py
CHANGED
@@ -47,11 +47,11 @@ class ClustererOutput(Output):
         The strength of the data point belonging to the assigned cluster
     """
 
-    clusters: NDArray[np.
-    mst: NDArray[np.
-    linkage_tree: NDArray[np.
-    condensed_tree: NDArray[np.
-    membership_strengths: NDArray[np.
+    clusters: NDArray[np.intp]
+    mst: NDArray[np.float32]
+    linkage_tree: NDArray[np.float32]
+    condensed_tree: NDArray[np.float32]
+    membership_strengths: NDArray[np.float32]
 
     def find_outliers(self) -> NDArray[np.int_]:
         """
@@ -77,7 +77,7 @@ class ClustererOutput(Output):
         # Delay load numba compiled functions
         from dataeval.utils._clusterer import compare_links_to_cluster_std, sorted_union_find
 
-        exact_indices, near_indices = compare_links_to_cluster_std(self.mst, self.clusters)
+        exact_indices, near_indices = compare_links_to_cluster_std(self.mst, self.clusters)  # type: ignore
         exact_dupes = sorted_union_find(exact_indices)
         near_dupes = sorted_union_find(near_indices)
 
dataeval/utils/_array.py
CHANGED
@@ -19,7 +19,7 @@ _logger = logging.getLogger(__name__)
 
 _MODULE_CACHE = {}
 
-T = TypeVar("T",
+T = TypeVar("T", Array, np.ndarray, torch.Tensor)
 _np_dtype = TypeVar("_np_dtype", bound=np.generic)
 
 
@@ -73,6 +73,19 @@ def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
         yield to_numpy(array)
 
 
+@overload
+def rescale_array(array: NDArray[_np_dtype]) -> NDArray[_np_dtype]: ...
+@overload
+def rescale_array(array: torch.Tensor) -> torch.Tensor: ...
+def rescale_array(array: Array | NDArray[_np_dtype] | torch.Tensor) -> Array | NDArray[_np_dtype] | torch.Tensor:
+    """Rescale an array to the range [0, 1]"""
+    if isinstance(array, (np.ndarray, torch.Tensor)):
+        arr_min = array.min()
+        arr_max = array.max()
+        return (array - arr_min) / (arr_max - arr_min)
+    raise TypeError(f"Unsupported type: {type(array)}")
+
+
 @overload
 def ensure_embeddings(
     embeddings: T,
@@ -137,14 +150,12 @@ def ensure_embeddings(
     if arr.ndim != 2:
         raise ValueError(f"Expected a 2D array, but got a {arr.ndim}D array.")
 
-    if unit_interval:
-
-
-
-
-
-    else:
-        raise ValueError("Embeddings must be unit interval [0, 1].")
+    if unit_interval and (arr.min() < 0 or arr.max() > 1):
+        if unit_interval == "force":
+            warnings.warn("Embeddings are not unit interval [0, 1]. Forcing to unit interval.")
+            arr = rescale_array(arr)
+        else:
+            raise ValueError("Embeddings must be unit interval [0, 1].")
 
     if dtype is None:
         return embeddings
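`rescale_array` is plain min-max scaling, and `ensure_embeddings` can now warn and rescale instead of raising when asked to force the unit interval. A standalone sketch of the same arithmetic (not the dataeval implementation):

```python
import numpy as np

def rescale(array: np.ndarray) -> np.ndarray:
    # Min-max scaling onto [0, 1], mirroring the new rescale_array helper
    arr_min, arr_max = array.min(), array.max()
    return (array - arr_min) / (arr_max - arr_min)

x = np.array([2.0, 4.0, 6.0])
print(rescale(x))  # [0.  0.5 1. ]
# Note: a constant array divides by zero here; the diff shows no guard for that case.
```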
dataeval/utils/_clusterer.py
CHANGED
@@ -69,12 +69,12 @@ def compare_links_to_cluster_std(
 @dataclass
 class ClusterData:
     clusters: NDArray[np.intp]
-    mst: NDArray[np.
-    linkage_tree: NDArray[np.
+    mst: NDArray[np.float32]
+    linkage_tree: NDArray[np.float32]
     condensed_tree: CondensedTree
-    membership_strengths: NDArray[np.
+    membership_strengths: NDArray[np.float32]
     k_neighbors: NDArray[np.int32]
-    k_distances: NDArray[np.
+    k_distances: NDArray[np.float32]
 
 
 def cluster(data: ArrayLike) -> ClusterData:
@@ -95,9 +95,9 @@ def cluster(data: ArrayLike) -> ClusterData:
 
     max_neighbors = min(25, num_samples - 1)
     kneighbors, kdistances = calculate_neighbor_distances(x, max_neighbors)
-    unsorted_mst: NDArray[np.
-    mst: NDArray[np.
-    linkage_tree: NDArray[np.
+    unsorted_mst: NDArray[np.float32] = minimum_spanning_tree(x, kneighbors, kdistances)
+    mst: NDArray[np.float32] = unsorted_mst[np.argsort(unsorted_mst.T[2])]
+    linkage_tree: NDArray[np.float32] = mst_to_linkage_tree(mst).astype(np.float32)
     condensed_tree: CondensedTree = condense_tree(linkage_tree, min_cluster_size, None)
 
     cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)
dataeval/utils/datasets/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from dataeval.utils.datasets._antiuav import AntiUAVDetection
 from dataeval.utils.datasets._cifar10 import CIFAR10
 from dataeval.utils.datasets._milco import MILCO
 from dataeval.utils.datasets._mnist import MNIST
+from dataeval.utils.datasets._seadrone import SeaDrone
 from dataeval.utils.datasets._ships import Ships
 from dataeval.utils.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
 
@@ -13,6 +14,7 @@ __all__ = [
     "CIFAR10",
     "AntiUAVDetection",
     "MILCO",
+    "SeaDrone",
     "VOCDetection",
     "VOCDetectionTorch",
     "VOCSegmentation",
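SeaDrone is the one new module in this release (512 added lines per the file list) and is now exported from `dataeval.utils.datasets`. The constructor arguments below are assumptions modeled on the sibling datasets, not read from `_seadrone.py`:

```python
from dataeval.utils.datasets import SeaDrone

# root/download arguments are assumed by analogy with MILCO and friends;
# check dataeval/utils/datasets/_seadrone.py for the real API.
dataset = SeaDrone(root="./data", download=True)
image, target, metadata = dataset[0]  # the BaseODDataset 3-tuple shown in _base.py below
```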
dataeval/utils/datasets/_antiuav.py
CHANGED
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     from dataeval.typing import Transform
 
 
-class AntiUAVDetection(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+class AntiUAVDetection(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
     """
     A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
 
dataeval/utils/datasets/_base.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 from abc import abstractmethod
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
 
 import numpy as np
 
@@ -28,7 +28,8 @@ else:
 _TArray = TypeVar("_TArray")
 
 _TTarget = TypeVar("_TTarget")
-_TRawTarget = TypeVar("_TRawTarget", list[int], list[
+_TRawTarget = TypeVar("_TRawTarget", Sequence[int], Sequence[str], Sequence[tuple[list[int], list[list[float]]]])
+_TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
 
 
 class DataLocation(NamedTuple):
@@ -38,7 +39,9 @@ class DataLocation(NamedTuple):
     checksum: str
 
 
-class BaseDataset(
+class BaseDataset(
+    AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation]
+):
     """
     Base class for internet downloaded datasets.
     """
@@ -144,7 +147,7 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget]):
 
 
 class BaseICDataset(
-    BaseDataset[_TArray, _TArray, list[int]],
+    BaseDataset[_TArray, _TArray, list[int], int],
     BaseDatasetMixin[_TArray],
     ImageClassificationDataset[_TArray],
 ):
@@ -177,7 +180,7 @@ class BaseICDataset(
 
 
 class BaseODDataset(
-    BaseDataset[_TArray, ObjectDetectionTarget[_TArray],
+    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
     BaseDatasetMixin[_TArray],
     ObjectDetectionDataset[_TArray],
 ):
@@ -200,7 +203,8 @@ class BaseODDataset(
         Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
         """
         # Grab the bounding boxes and labels from the annotations
-
+        annotation = cast(_TAnnotation, self._targets[index])
+        boxes, labels, additional_metadata = self._read_annotations(annotation)
         # Get the image
         img = self._read_file(self._filepaths[index])
         img_size = img.shape
@@ -217,11 +221,11 @@ class BaseODDataset(
         return img, target, img_metadata
 
     @abstractmethod
-    def _read_annotations(self, annotation:
+    def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
 
 
 class BaseSegDataset(
-    BaseDataset[_TArray, SegmentationTarget[_TArray], list[str]],
+    BaseDataset[_TArray, SegmentationTarget[_TArray], list[str], str],
     BaseDatasetMixin[_TArray],
     SegmentationDataset[_TArray],
 ):
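The dataset base classes now carry two extra type parameters: `_TRawTarget`, the container of raw annotations, and `_TAnnotation`, the type of a single annotation. So `BaseODDataset[NDArray[Any], list[str], str]` reads as numpy images, annotations stored as a list of file paths, one path string per annotation. A hypothetical subclass showing how the parameters line up (the private import paths are assumptions; the body is invented):

```python
from typing import Any
from numpy.typing import NDArray

from dataeval.utils.datasets._base import BaseODDataset  # private module, path assumed
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin  # path assumed

class MySonarDataset(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
        # `annotation` is one element of the raw target list (_TAnnotation = str)
        boxes = [[0.0, 0.0, 10.0, 10.0]]  # x0, y0, x1, y1, per the docstring above
        labels = [0]
        return boxes, labels, {"source": annotation}
```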
dataeval/utils/datasets/_fileio.py
CHANGED
@@ -128,9 +128,9 @@ def _ensure_exists(
 
     elif not check_path.exists() and not download:
         raise FileNotFoundError(
-            "Data could not be loaded with the provided root directory, "
-            f"the file path to the file {filename} does not exist, "
-            "and the download parameter is set to False."
+            "Data could not be loaded with the provided root directory, "
+            f"the file path to the file {filename} does not exist, "
+            "and the download parameter is set to False."
         )
     else:
         if not _validate_file(check_path, checksum, md5):
dataeval/utils/datasets/_milco.py
CHANGED
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
     from dataeval.typing import Transform
 
 
-class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+class MILCO(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
     """
     A side-scan sonar dataset focused on mine-like object detection.
 