dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
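
The dominant pattern in this release is a move from public module files to underscore-prefixed private modules (e.g. `drift/ks.py` → `drift/_ks.py`), with result types consolidated under the new `dataeval/outputs/` package and dataset utilities under `dataeval/utils/data/`. A minimal consumer-side sketch, assuming (as the modified `__init__.py` entries above suggest, but which is not confirmed by this diff) that public names are re-exported from the package `__init__` files; the specific class and function names below are assumptions:

```python
# Hypothetical imports against 0.82.0. Module files became private
# (ks.py -> _ks.py), so consumers import from the packages, not the modules.
from dataeval.detectors.drift import DriftKS  # was dataeval.detectors.drift.ks
from dataeval.metrics.bias import balance     # was dataeval.metrics.bias.balance

# Importing a private module directly would pin code to an internal path:
# from dataeval.detectors.drift._ks import DriftKS  # discouraged
```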
The full contents of five of the new `utils/data/datasets` modules follow.

`dataeval/utils/data/datasets/_milco.py` (new file, `@@ -0,0 +1,153 @@`):

```python
from __future__ import annotations

__all__ = []

from pathlib import Path
from typing import Any, Sequence

from numpy.typing import NDArray

from dataeval.utils.data.datasets._base import BaseODDataset, DataLocation
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
from dataeval.utils.data.datasets._types import Transform


class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
    """
    A side-scan sonar dataset focused on mine (object) detection.

    The dataset comes from the paper
    `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
    by N.P. Santos et al. (2024).

    This class only accesses a portion of the above dataset due to size constraints.
    The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
    dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
    All the images were carefully analyzed and annotated, including the image coordinates of the
    Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
    and MIne-Like COntacts (MILCO) classes.

    This dataset consists of 261 images (120 images from 2015, 93 images from 2017, and 48 images from 2021).
    In these 261 images, there are 315 MILCO objects and 175 NOMBO objects.
    The class "0" corresponds to a MILCO object and the class "1" corresponds to a NOMBO object.
    The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
    given as fractions of the image size (x_BB = x/img_width, y_BB = y/img_height, etc.).
    The images come in 2 sizes, 416 x 416 or 1024 x 1024.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of dataset where the ``milco`` folder exists.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in the root directory.
        The class checks whether the data is already downloaded to avoid a duplicate download.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform(s) to apply to the data.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """

    _resources = [
        DataLocation(
            url="https://figshare.com/ndownloader/files/43169002",
            filename="2015.zip",
            md5=True,
            checksum="93dfbb4fb7987734152c372496b4884c",
        ),
        DataLocation(
            url="https://figshare.com/ndownloader/files/43169005",
            filename="2017.zip",
            md5=True,
            checksum="9c2de230a2bbf654921416bea6fc0f42",
        ),
        DataLocation(
            url="https://figshare.com/ndownloader/files/43168999",
            filename="2021.zip",
            md5=True,
            checksum="b84749b21fa95a4a4c7de3741db78bc7",
        ),
    ]

    index2label: dict[int, str] = {
        0: "MILCO",
        1: "NOMBO",
    }

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            root,
            download,
            "base",
            transforms,
            verbose,
        )

    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
        filepaths: list[str] = []
        targets: list[str] = []
        datum_metadata: dict[str, list[Any]] = {}
        metadata_list: list[dict[str, Any]] = []

        # Load the data
        for resource in self._resources:
            self._resource = resource
            filepath, target, metadata = super()._load_data()
            filepaths.extend(filepath)
            targets.extend(target)
            metadata_list.append(metadata)

        # Adjust datum metadata to correct format
        for data_dict in metadata_list:
            for key, val in data_dict.items():
                if key not in datum_metadata:
                    datum_metadata[str(key)] = []
                datum_metadata[str(key)].extend(val)

        return filepaths, targets, datum_metadata

    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
        file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
        data_folder = self.path / self._resource.filename[:-4]
        for entry in data_folder.iterdir():
            if entry.is_file() and entry.suffix == ".jpg":
                # Remove file extension and split by "_"
                parts = entry.stem.split("_")
                file_data["image_id"].append(parts[0])
                file_data["year"].append(parts[1])
                file_data["data_path"].append(str(entry))
                file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
        data = file_data.pop("data_path")
        annotations = file_data.pop("label_path")

        return data, annotations, file_data

    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
        """Function for extracting the info out of the text files"""
        labels: list[int] = []
        boxes: list[list[float]] = []
        with open(annotation) as f:
            for line in f.readlines():
                out = line.strip().split(" ")
                labels.append(int(out[0]))
                boxes.append([float(out[1]), float(out[2]), float(out[3]), float(out[4])])

        return boxes, labels, {}
```
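
The label files store boxes as normalized (x, y, w, h) fractions of the image size, exactly as `_read_annotations` returns them. A small sketch (not part of the package) converting one row to pixel corners; it assumes the common YOLO convention where (x, y) is the box center, which the docstring does not state explicitly:

```python
def denormalize_box(
    x: float, y: float, w: float, h: float, img_w: int, img_h: int
) -> tuple[float, float, float, float]:
    """Convert a normalized MILCO label row to pixel corners (x0, y0, x1, y1).

    Assumes (x, y) is the box *center*, per the usual YOLO convention.
    """
    cx, cy = x * img_w, y * img_h    # center in pixels
    bw, bh = w * img_w, h * img_h    # width/height in pixels
    return cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2


# e.g. for one of the 416 x 416 images:
print(denormalize_box(0.5, 0.5, 0.25, 0.25, 416, 416))  # (156.0, 156.0, 260.0, 260.0)
```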
`dataeval/utils/data/datasets/_mixin.py` (new file, `@@ -0,0 +1,56 @@`):

```python
from __future__ import annotations

__all__ = []

from typing import Any, Generic, TypeVar

import numpy as np
import torch
from numpy.typing import NDArray
from PIL import Image

_TArray = TypeVar("_TArray")


class BaseDatasetMixin(Generic[_TArray]):
    index2label: dict[int, str]

    def _as_array(self, raw: list[Any]) -> _TArray: ...
    def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
    def _read_file(self, path: str) -> _TArray: ...


class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
    def _as_array(self, raw: list[Any]) -> NDArray[Any]:
        return np.asarray(raw)

    def _one_hot_encode(self, value: int | list[int]) -> NDArray[Any]:
        if isinstance(value, int):
            encoded = np.zeros(len(self.index2label))
            encoded[value] = 1
        else:
            encoded = np.zeros((len(value), len(self.index2label)))
            encoded[np.arange(len(value)), value] = 1
        return encoded

    def _read_file(self, path: str) -> NDArray[Any]:
        x = np.array(Image.open(path)).transpose(2, 0, 1)
        return x


class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
    def _as_array(self, raw: list[Any]) -> torch.Tensor:
        return torch.as_tensor(raw)

    def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
        if isinstance(value, int):
            encoded = torch.zeros(len(self.index2label))
            encoded[value] = 1
        else:
            encoded = torch.zeros((len(value), len(self.index2label)))
            encoded[torch.arange(len(value)), value] = 1
        return encoded

    def _read_file(self, path: str) -> torch.Tensor:
        x = torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
        return x
```
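
The batched branch of `_one_hot_encode` relies on fancy indexing: indexing the rows with `arange(n)` and the columns with the label list writes one cell per row in a single vectorized assignment. A standalone illustration of that trick (not part of the package):

```python
import numpy as np

labels = [0, 1, 1, 0]                        # class indices for a batch of 4
encoded = np.zeros((len(labels), 2))         # 4 samples x 2 classes
encoded[np.arange(len(labels)), labels] = 1  # one cell per row via fancy indexing
print(encoded)
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]
#  [1. 0.]]
```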
`dataeval/utils/data/datasets/_mnist.py` (new file, `@@ -0,0 +1,183 @@`):

```python
from __future__ import annotations

__all__ = []

from pathlib import Path
from typing import Any, Literal, Sequence, TypeVar

import numpy as np
from numpy.typing import NDArray

from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
from dataeval.utils.data.datasets._types import Transform

MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
CorruptionStringMap = Literal[
    "identity",
    "shot_noise",
    "impulse_noise",
    "glass_blur",
    "motion_blur",
    "shear",
    "scale",
    "rotate",
    "brightness",
    "translate",
    "stripe",
    "fog",
    "spatter",
    "dotted_line",
    "zigzag",
    "canny_edges",
]


class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
    """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.

    There are 15 different styles of corruptions. This class downloads differently depending on
    whether you need just the original dataset or the corruptions. If you need both a corrupt
    version and the original version, choose `corruption="identity"`, as this downloads all of
    the corrupt datasets and provides the original as `identity`. If you just need the original,
    `corruption=None` will download only the original dataset to save time and space.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of dataset where the ``mnist`` folder exists.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in the root directory.
        The class checks whether the data is already downloaded to avoid a duplicate download.
    image_set : "train", "test" or "base", default "train"
        If "base", returns all of the data to allow the user to create their own splits.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """

    _resources = [
        DataLocation(
            url="https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
            filename="mnist.npz",
            md5=False,
            checksum="731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1",
        ),
        DataLocation(
            url="https://zenodo.org/record/3239543/files/mnist_c.zip",
            filename="mnist_c.zip",
            md5=True,
            checksum="4b34b33045869ee6d424616cd3a65da3",
        ),
    ]

    index2label: dict[int, str] = {
        0: "zero",
        1: "one",
        2: "two",
        3: "three",
        4: "four",
        5: "five",
        6: "six",
        7: "seven",
        8: "eight",
        9: "nine",
    }

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        image_set: Literal["train", "test", "base"] = "train",
        corruption: CorruptionStringMap | None = None,
        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
        verbose: bool = False,
    ) -> None:
        self.corruption = corruption
        if self.corruption == "identity" and verbose:
            print("Identity is not a corrupted dataset but the original MNIST dataset.")
        self._resource_index = 0 if self.corruption is None else 1

        super().__init__(
            root,
            download,
            image_set,
            transforms,
            verbose,
        )

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        """Function to load in the file paths for the data and labels from the correct data format"""
        if self.corruption is None:
            try:
                file_path = self.path / self._resource.filename
                self._loaded_data, labels = self._grab_data(file_path)
            except FileNotFoundError:
                self._loaded_data, labels = self._load_corruption()
        else:
            self._loaded_data, labels = self._load_corruption()

        index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
        return index_strings, labels.tolist(), {}

    def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
        """Function to load in the file paths for the data and labels for the different corrupt data formats"""
        corruption = self.corruption if self.corruption is not None else "identity"
        base_path = self.path / corruption
        if self.image_set == "base":
            raw_data = []
            raw_labels = []
            for group in ["train", "test"]:
                file_path = base_path / f"{group}_images.npy"
                raw_data.append(self._grab_corruption_data(file_path))

                label_path = base_path / f"{group}_labels.npy"
                raw_labels.append(self._grab_corruption_data(label_path))

            data = np.concatenate(raw_data, axis=0).transpose(0, 3, 1, 2)
            labels = np.concatenate(raw_labels).astype(np.uintp)
        else:
            file_path = base_path / f"{self.image_set}_images.npy"
            data = self._grab_corruption_data(file_path)
            data = data.astype(np.float64).transpose(0, 3, 1, 2)

            label_path = base_path / f"{self.image_set}_labels.npy"
            labels = self._grab_corruption_data(label_path)
            labels = labels.astype(np.uintp)

        return data, labels

    def _grab_data(self, path: Path) -> tuple[NDArray[Any], NDArray[np.uintp]]:
        """Function to load in the data numpy array"""
        with np.load(path, allow_pickle=True) as data_array:
            if self.image_set == "base":
                data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
                labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
            else:
                data, labels = data_array[f"x_{self.image_set}"], data_array[f"y_{self.image_set}"].astype(np.uintp)
            data = np.expand_dims(data, axis=1)
        return data, labels

    def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
        """Function to load in the data numpy array for the previously chosen corrupt format"""
        x = np.load(path, allow_pickle=False)
        return x

    def _read_file(self, path: str) -> NDArray[Any]:
        """
        Function to grab the correct image from the loaded data.
        Overwrite of the base `_read_file` because data is an all or nothing load.
        """
        index = int(path)
        return self._loaded_data[index]
```
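
A short usage sketch, assuming the new `datasets` package re-exports `MNIST` (plausible given the `datasets/__init__.py` entry in the file list, but not confirmed here); the root path and download flag are illustrative:

```python
from dataeval.utils.data.datasets import MNIST  # assumed re-export; the module itself is _mnist.py

# Original MNIST only: downloads just mnist.npz.
clean = MNIST(root="./data", download=True, image_set="train")

# MNIST-C "fog" split: any corruption value downloads the full mnist_c.zip,
# with the uncorrupted images available under corruption="identity".
foggy = MNIST(root="./data", download=True, image_set="test", corruption="fog")

# Per the dataset tuple types in _types.py below: (image array, target, per-datum metadata).
image, target, metadata = foggy[0]
```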
`dataeval/utils/data/datasets/_ships.py` (new file, `@@ -0,0 +1,123 @@`):

```python
from __future__ import annotations

__all__ = []

from pathlib import Path
from typing import Any, Sequence

import numpy as np
from numpy.typing import NDArray

from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
from dataeval.utils.data.datasets._types import Transform


class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
    """
    A dataset that focuses on identifying ships from satellite images.

    The dataset comes from Kaggle,
    `Ships in Satellite Imagery <https://www.kaggle.com/datasets/rhammell/ships-in-satellite-imagery>`_.
    The images come from Planet satellite imagery, released when Planet gave
    `open-access to a portion of their data <https://www.planet.com/pulse/open-california-rapideye-data/>`_.

    There are 4000 80x80x3 (HWC) images of ships, sea, and land.
    There are also 8 larger scene images similar to what would be operationally provided.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of dataset where the ``shipdataset`` folder exists.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in the root directory.
        The class checks whether the data is already downloaded to avoid a duplicate download.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """

    _resources = [
        DataLocation(
            url="https://zenodo.org/record/3611230/files/ships-in-satellite-imagery.zip",
            filename="ships-in-satellite-imagery.zip",
            md5=True,
            checksum="b2e8a41ed029592b373bd72ee4b89f32",
        ),
    ]

    index2label: dict[int, str] = {
        0: "no ship",
        1: "ship",
    }

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            root,
            download,
            "base",
            transforms,
            verbose,
        )
        self._scenes: list[str] = self._load_scenes()

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        """Function to load in the file paths for the data and labels"""
        file_data = {"label": [], "scene_id": [], "longitude": [], "latitude": [], "path": []}
        data_folder = self.path / "shipsnet"
        for entry in data_folder.iterdir():
            # Remove file extension and split by "__"
            parts = entry.stem.split("__")  # Removes ".png" and splits the string
            file_data["label"].append(int(parts[0]))
            file_data["scene_id"].append(parts[1])
            lat_lon = parts[2].split("_")
            file_data["longitude"].append(float(lat_lon[0]))
            file_data["latitude"].append(float(lat_lon[1]))
            file_data["path"].append(entry)
        data = file_data.pop("path")
        labels = file_data.pop("label")
        return data, labels, file_data

    def _load_scenes(self) -> list[str]:
        """Function to load in the file paths for the scene images"""
        data_folder = self.path / "scenes"
        scene = [str(entry) for entry in data_folder.iterdir()]
        return scene

    def get_scene(self, index: int) -> NDArray[np.uintp]:
        """
        Get the desired satellite image (scene) by passing in the index of the desired file.

        Args
        ----
        index : int
            Value of the desired data point

        Returns
        -------
        NDArray[np.uintp]
            Scene image

        Note
        ----
        The scene will be returned with the channel axis first.
        """
        # _read_file (from BaseDatasetNumpyMixin) already returns the image
        # channels-first, so no further axis move is needed here.
        scene = self._read_file(self._scenes[index])
        return scene
```
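
Usage follows the same pattern as the other image-classification datasets, with `get_scene` alongside the normal indexing interface (paths illustrative, package re-export assumed):

```python
from dataeval.utils.data.datasets import Ships  # assumed re-export; the module itself is _ships.py

ships = Ships(root="./data", download=True)

# Per the tuple types in _types.py below: (chip array, target, metadata),
# where the metadata dict carries the parsed scene_id/longitude/latitude fields.
chip, target, meta = ships[0]

scene = ships.get_scene(0)  # one of the 8 full scenes, channels-first
```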
`dataeval/utils/data/datasets/_types.py` (new file, `@@ -0,0 +1,52 @@`):

```python
from __future__ import annotations

__all__ = []

from dataclasses import dataclass
from typing import Any, Generic, Protocol, TypedDict, TypeVar

from torch.utils.data import Dataset
from typing_extensions import NotRequired, Required


class DatasetMetadata(TypedDict):
    id: Required[str]
    index2label: NotRequired[dict[int, str]]
    split: NotRequired[str]


_TDatum = TypeVar("_TDatum")
_TArray = TypeVar("_TArray")


class AnnotatedDataset(Dataset[_TDatum]):
    metadata: DatasetMetadata

    def __len__(self) -> int: ...


class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, dict[str, Any]]]): ...


@dataclass
class ObjectDetectionTarget(Generic[_TArray]):
    boxes: _TArray
    labels: _TArray
    scores: _TArray


class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]]): ...


@dataclass
class SegmentationTarget(Generic[_TArray]):
    mask: _TArray
    labels: _TArray
    scores: _TArray


class SegmentationDataset(AnnotatedDataset[tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]]): ...


class Transform(Generic[_TArray], Protocol):
    def __call__(self, data: _TArray, /) -> _TArray: ...
```
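
`Transform` is a structural `Protocol`: any callable that takes a single positional array and returns an array conforms, with no inheritance required. A minimal sketch of a conforming transform (hypothetical, not part of the package):

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray


class UnitScale:
    """Scales uint8 pixel data to [0, 1]; satisfies Transform[NDArray[Any]]
    purely structurally through a matching __call__ signature."""

    def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
        return data.astype(np.float32) / 255.0


# Plain functions conform as well:
def to_float(data: NDArray[Any], /) -> NDArray[Any]:
    return data.astype(np.float32)
```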