dataeval 0.86.9__py3-none-any.whl → 0.87.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_version.py +2 -2
  3. dataeval/config.py +4 -19
  4. dataeval/data/_metadata.py +56 -27
  5. dataeval/data/_split.py +1 -1
  6. dataeval/data/selections/_classbalance.py +4 -3
  7. dataeval/data/selections/_classfilter.py +5 -5
  8. dataeval/data/selections/_indices.py +2 -2
  9. dataeval/data/selections/_prioritize.py +249 -29
  10. dataeval/data/selections/_reverse.py +1 -1
  11. dataeval/data/selections/_shuffle.py +2 -2
  12. dataeval/detectors/ood/__init__.py +2 -1
  13. dataeval/detectors/ood/base.py +38 -1
  14. dataeval/detectors/ood/knn.py +95 -0
  15. dataeval/metrics/bias/_balance.py +28 -21
  16. dataeval/metrics/bias/_diversity.py +4 -4
  17. dataeval/metrics/bias/_parity.py +2 -2
  18. dataeval/metrics/stats/_hashstats.py +19 -2
  19. dataeval/outputs/_workflows.py +20 -7
  20. dataeval/typing.py +14 -2
  21. dataeval/utils/__init__.py +2 -2
  22. dataeval/utils/_bin.py +7 -6
  23. dataeval/utils/data/__init__.py +2 -0
  24. dataeval/utils/data/_dataset.py +13 -6
  25. dataeval/utils/data/_validate.py +169 -0
  26. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/METADATA +5 -17
  27. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/RECORD +29 -39
  28. dataeval/utils/datasets/__init__.py +0 -21
  29. dataeval/utils/datasets/_antiuav.py +0 -189
  30. dataeval/utils/datasets/_base.py +0 -266
  31. dataeval/utils/datasets/_cifar10.py +0 -201
  32. dataeval/utils/datasets/_fileio.py +0 -142
  33. dataeval/utils/datasets/_milco.py +0 -197
  34. dataeval/utils/datasets/_mixin.py +0 -54
  35. dataeval/utils/datasets/_mnist.py +0 -202
  36. dataeval/utils/datasets/_seadrone.py +0 -512
  37. dataeval/utils/datasets/_ships.py +0 -144
  38. dataeval/utils/datasets/_types.py +0 -48
  39. dataeval/utils/datasets/_voc.py +0 -583
  40. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/WHEEL +0 -0
  41. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.87.0.dist-info/licenses/LICENSE +0 -0
dataeval/utils/datasets/_cifar10.py
@@ -1,201 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
-
-import numpy as np
-from numpy.typing import NDArray
-
-from dataeval.utils.datasets._base import BaseICDataset, DataLocation
-from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
-if TYPE_CHECKING:
-    from dataeval.typing import Transform
-
-CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
-TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
-
-
-class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
-    """
-    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
-
-    Parameters
-    ----------
-    root : str or pathlib.Path
-        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
-    image_set : "train", "test" or "base", default "train"
-        If "base", returns all of the data to allow the user to create their own splits.
-    transforms : Transform, Sequence[Transform] or None, default None
-        Transform(s) to apply to the data.
-    download : bool, default False
-        If True, downloads the dataset from the internet and puts it in root directory.
-        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-    verbose : bool, default False
-        If True, outputs print statements.
-
-    Attributes
-    ----------
-    path : pathlib.Path
-        Location of the folder containing the data.
-    image_set : "train", "test" or "base"
-        The selected image set from the dataset.
-    transforms : Sequence[Transform]
-        The transforms to be applied to the data.
-    size : int
-        The size of the dataset.
-    index2label : dict[int, str]
-        Dictionary which translates from class integers to the associated class strings.
-    label2index : dict[str, int]
-        Dictionary which translates from class strings to the associated class integers.
-    metadata : DatasetMetadata
-        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-    """
-
-    _resources = [
-        DataLocation(
-            url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
-            filename="cifar-10-binary.tar.gz",
-            md5=True,
-            checksum="c32a1d4ab5d03f1284b67883e8d87530",
-        ),
-    ]
-
-    index2label: dict[int, str] = {
-        0: "airplane",
-        1: "automobile",
-        2: "bird",
-        3: "cat",
-        4: "deer",
-        5: "dog",
-        6: "frog",
-        7: "horse",
-        8: "ship",
-        9: "truck",
-    }
-
-    def __init__(
-        self,
-        root: str | Path,
-        image_set: Literal["train", "test", "base"] = "train",
-        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-        download: bool = False,
-        verbose: bool = False,
-    ) -> None:
-        super().__init__(
-            root,
-            image_set,
-            transforms,
-            download,
-            verbose,
-        )
-
-    def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
-        batch_nums = np.zeros(60000, dtype=np.uint8)
-        all_labels = np.zeros(60000, dtype=np.uint8)
-        all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
-        # Process each batch file, skipping .meta and .html files
-        for batch_file in data_folder:
-            # Get batch parameters
-            batch_type = "test" if "test" in batch_file.stem else "train"
-            batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
-
-            # Load data
-            batch_images, batch_labels = self._unpack_batch_files(batch_file)
-
-            # Stack data
-            num_images = batch_images.shape[0]
-            batch_start = batch_num * num_images
-            all_images[batch_start : batch_start + num_images] = batch_images
-            all_labels[batch_start : batch_start + num_images] = batch_labels
-            batch_nums[batch_start : batch_start + num_images] = batch_num
-
-        # Save data
-        self._loaded_data = all_images
-        np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
-
-        # Select data
-        image_list = np.arange(all_labels.shape[0]).astype(str)
-        if self.image_set == "train":
-            return (
-                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                all_labels[batch_nums != 5].tolist(),
-                {"batch_num": batch_nums[batch_nums != 5].tolist()},
-            )
-        if self.image_set == "test":
-            return (
-                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                all_labels[batch_nums == 5].tolist(),
-                {"batch_num": batch_nums[batch_nums == 5].tolist()},
-            )
-        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
-        """Function to load in the file paths for the data and labels and retrieve metadata"""
-        data_file = self.path / "cifar10.npz"
-        if not data_file.exists():
-            data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
-            if not data_folder:
-                raise FileNotFoundError
-            return self._load_bin_data(data_folder)
-
-        # Load data
-        data = np.load(data_file)
-        self._loaded_data = data["images"]
-        all_labels = data["labels"]
-        batch_nums = data["batches"]
-
-        # Select data
-        image_list = np.arange(all_labels.shape[0]).astype(str)
-        if self.image_set == "train":
-            return (
-                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                all_labels[batch_nums != 5].tolist(),
-                {"batch_num": batch_nums[batch_nums != 5].tolist()},
-            )
-        if self.image_set == "test":
-            return (
-                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                all_labels[batch_nums == 5].tolist(),
-                {"batch_num": batch_nums[batch_nums == 5].tolist()},
-            )
-        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
-        # Load pickle data with latin1 encoding
-        with file_path.open("rb") as f:
-            buffer = np.frombuffer(f.read(), dtype=np.uint8)
-        # Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
-        entry_size = 1 + 3072
-        num_entries = buffer.size // entry_size
-        # Extract labels (first byte of each entry)
-        labels = buffer[::entry_size]
-
-        # Extract image data and reshape to (N, 3, 32, 32)
-        images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
-        for i in range(num_entries):
-            # Skip the label byte and get image data for this entry
-            start_idx = i * entry_size + 1  # +1 to skip label
-            img_flat = buffer[start_idx : start_idx + 3072]
-
-            # The CIFAR format stores channels in blocks (all R, then all G, then all B)
-            # Each channel block is 1024 bytes (32x32)
-            red_channel = img_flat[0:1024].reshape(32, 32)
-            green_channel = img_flat[1024:2048].reshape(32, 32)
-            blue_channel = img_flat[2048:3072].reshape(32, 32)
-
-            # Stack the channels in the proper C×H×W format
-            images[i, 0] = red_channel  # Red channel
-            images[i, 1] = green_channel  # Green channel
-            images[i, 2] = blue_channel  # Blue channel
-        return images, labels
-
-    def _read_file(self, path: str) -> NDArray[Any]:
-        """
-        Function to grab the correct image from the loaded data.
-        Overwrite of the base `_read_file` because data is an all or nothing load.
-        """
-        index = int(path)
-        return self._loaded_data[index]
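A minimal usage sketch of the removed class as it existed in 0.86.9, based only on the signature and docstring in the hunk above; the expected output values are assumptions, not part of this diff.

    # Hypothetical 0.86.9-era usage; this module path no longer exists in 0.87.0.
    from dataeval.utils.datasets._cifar10 import CIFAR10

    # Arguments mirror the __init__ signature shown in the hunk above.
    train_ds = CIFAR10(root="./data", image_set="train", download=True, verbose=True)
    print(train_ds.size)            # expected 50000 for the CIFAR-10 train split
    print(train_ds.index2label[9])  # "truck", per the class mapping above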
dataeval/utils/datasets/_fileio.py
@@ -1,142 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-import hashlib
-import tarfile
-import zipfile
-from pathlib import Path
-
-import requests
-from tqdm.auto import tqdm
-
-ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
-COMPRESS_ENDINGS = [".gz", ".bz2"]
-
-
-def _print(text: str, verbose: bool) -> None:
-    if verbose:
-        print(text)
-
-
-def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
-    hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
-    with open(fpath, "rb") as fpath_file:
-        while chunk := fpath_file.read(chunk_size):
-            hasher.update(chunk)
-    return hasher.hexdigest() == file_md5
-
-
-def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
-    """Download a single resource from its URL to the `data_folder`."""
-    error_msg = "URL fetch failure on {}: {} -- {}"
-    try:
-        response = requests.get(url, stream=True, timeout=timeout)
-        response.raise_for_status()
-    except requests.exceptions.HTTPError as e:
-        raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
-    except requests.exceptions.RequestException as e:
-        raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
-
-    total_size = int(response.headers.get("content-length", 0))
-    block_size = 8192  # 8 KB
-    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
-
-    with open(file_path, "wb") as f:
-        for chunk in response.iter_content(block_size):
-            f.write(chunk)
-            progress_bar.update(len(chunk))
-        progress_bar.close()
-
-
-def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
-    """Extracts the zip file to the given directory."""
-    try:
-        with zipfile.ZipFile(file_path, "r") as zip_ref:
-            zip_ref.extractall(extract_to)  # noqa: S202
-        file_path.unlink()
-    except zipfile.BadZipFile:
-        raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
-
-
-def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
-    """Extracts a tar file (or compressed tar) to the specified directory."""
-    try:
-        with tarfile.open(file_path, "r:*") as tar_ref:
-            tar_ref.extractall(extract_to)  # noqa: S202
-        file_path.unlink()
-    except tarfile.TarError:
-        raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
-
-
-def _extract_archive(
-    file_ext: str, file_path: Path, directory: Path, compression: bool = False, verbose: bool = False
-) -> None:
-    """
-    Single function to extract and then flatten if necessary.
-    Recursively extracts nested zip files as well.
-    Extracts and flattens all folders to the base directory.
-    """
-    if file_ext != ".zip" or compression:
-        _extract_tar_archive(file_path, directory)
-    else:
-        _extract_zip_archive(file_path, directory)
-        # Look for nested zip files in the extraction directory and extract them recursively.
-        # Does NOT extract in place - extracts everything to directory
-        for child in directory.iterdir():
-            if child.suffix == ".zip":
-                _print(f"Extracting nested zip: {child} to {directory}", verbose)
-                _extract_zip_archive(child, directory)
-
-
-def _ensure_exists(
-    url: str,
-    filename: str,
-    md5: bool,
-    checksum: str,
-    directory: Path,
-    root: Path,
-    download: bool = True,
-    verbose: bool = False,
-) -> None:
-    """
-    For each resource, download it if it doesn't exist in the dataset_dir.
-    If the resource is a zip file, extract it (including recursively extracting nested zips).
-    """
-    file_path = directory / str(filename)
-    alternate_path = root / str(filename)
-    _, file_ext = file_path.stem, file_path.suffix
-    compression = False
-    if file_ext in COMPRESS_ENDINGS:
-        file_ext = file_path.suffixes[0]
-        compression = True
-
-    check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
-
-    # Download file if it doesn't exist.
-    if not check_path.exists() and download:
-        _print(f"Downloading {filename} from {url}", verbose)
-        _download_dataset(url, check_path, verbose=verbose)
-
-        if not _validate_file(check_path, checksum, md5):
-            raise Exception("File checksum mismatch. Remove current file and retry download.")
-
-        # If the file is a zip, tar or tgz extract it into the designated folder.
-        if file_ext in ARCHIVE_ENDINGS:
-            _print(f"Extracting {filename}...", verbose)
-            _extract_archive(file_ext, check_path, directory, compression, verbose)
-
-    elif not check_path.exists() and not download:
-        raise FileNotFoundError(
-            "Data could not be loaded with the provided root directory, "
-            f"the file path to the file {filename} does not exist, "
-            "and the download parameter is set to False."
-        )
-    else:
-        if not _validate_file(check_path, checksum, md5):
-            raise Exception("File checksum mismatch. Remove current file and retry download.")
-        _print(f"{filename} already exists, skipping download.", verbose)
-
-        if file_ext in ARCHIVE_ENDINGS:
-            _print(f"Extracting {filename}...", verbose)
-            _extract_archive(file_ext, check_path, directory, compression, verbose)
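A minimal sketch of how these removed private helpers fit together (the download, checksum, and extract flow that _ensure_exists implements), reusing the CIFAR-10 DataLocation values from the previous hunk; the local paths are illustrative assumptions and the import path only existed through 0.86.9.

    # Illustrative only: 0.86.9-era internals, removed in 0.87.0.
    from pathlib import Path

    from dataeval.utils.datasets._fileio import _download_dataset, _extract_archive, _validate_file

    directory = Path("./data/cifar10")            # assumed destination folder
    directory.mkdir(parents=True, exist_ok=True)
    archive = directory / "cifar-10-binary.tar.gz"

    _download_dataset("https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", archive, verbose=True)
    if not _validate_file(archive, "c32a1d4ab5d03f1284b67883e8d87530", md5=True):
        raise Exception("File checksum mismatch. Remove current file and retry download.")
    # For a .tar.gz, _ensure_exists passes file_ext=".tar" and compression=True.
    _extract_archive(".tar", archive, directory, compression=True, verbose=True)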
dataeval/utils/datasets/_milco.py
@@ -1,197 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Sequence
-
-from numpy.typing import NDArray
-
-from dataeval.utils.datasets._base import BaseODDataset, DataLocation
-from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
-if TYPE_CHECKING:
-    from dataeval.typing import Transform
-
-
-class MILCO(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
-    """
-    A side-scan sonar dataset focused on mine-like object detection.
-
-    The dataset comes from the paper
-    `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
-    by N.P. Santos et al. (2024).
-
-    The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
-    dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
-    All the images were carefully analyzed and annotated, including the image coordinates of the
-    Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
-    and MIne-Like COntacts (MILCO) classes.
-
-    This dataset consists of 345 images from 2010, 120 images from 2015, 93 images from 2017, 564 images from 2018,
-    and 48 images from 2021. In these 1170 images, there are 432 MILCO objects and 235 NOMBO objects.
-    The class “0” corresponds to a MILCO object and the class “1” corresponds to a NOMBO object.
-    The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
-    given as percentages of the image (x_BB = x/img_width, y_BB = y/img_height, etc.).
-    The images come in 2 sizes, 416 x 416 or 1024 x 1024.
-
-    Parameters
-    ----------
-    root : str or pathlib.Path
-        Root directory where the data should be downloaded to or the ``milco`` folder of the already downloaded data.
-    image_set : "train", "operational", or "base", default "train"
-        If "train", then the images from 2015, 2017 and 2021 are selected,
-        resulting in 315 MILCO objects and 177 NOMBO objects.
-        If "operational", then the images from 2010 and 2018 are selected,
-        resulting in 117 MILCO objects and 58 NOMBO objects.
-        If "base", then the full dataset is selected.
-    transforms : Transform, Sequence[Transform] or None, default None
-        Transform(s) to apply to the data.
-    download : bool, default False
-        If True, downloads the dataset from the internet and puts it in root directory.
-        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-    verbose : bool, default False
-        If True, outputs print statements.
-
-    Attributes
-    ----------
-    path : pathlib.Path
-        Location of the folder containing the data.
-    image_set : "train", "operational" or "base"
-        The selected image set from the dataset.
-    index2label : dict[int, str]
-        Dictionary which translates from class integers to the associated class strings.
-    label2index : dict[str, int]
-        Dictionary which translates from class strings to the associated class integers.
-    metadata : DatasetMetadata
-        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-    transforms : Sequence[Transform]
-        The transforms to be applied to the data.
-    size : int
-        The size of the dataset.
-
-    Note
-    ----
-    Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_
-    """
-
-    _resources = [
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169002",
-            filename="2015.zip",
-            md5=True,
-            checksum="93dfbb4fb7987734152c372496b4884c",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169005",
-            filename="2017.zip",
-            md5=True,
-            checksum="9c2de230a2bbf654921416bea6fc0f42",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43168999",
-            filename="2021.zip",
-            md5=True,
-            checksum="b84749b21fa95a4a4c7de3741db78bc7",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169008",
-            filename="2010.zip",
-            md5=True,
-            checksum="43347a0cc383c0d3dbe0d24ae56f328d",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169011",
-            filename="2018.zip",
-            md5=True,
-            checksum="25d091044a10c78674fedad655023e3b",
-        ),
-    ]
-
-    index2label: dict[int, str] = {
-        0: "MILCO",
-        1: "NOMBO",
-    }
-
-    def __init__(
-        self,
-        root: str | Path,
-        image_set: Literal["train", "operational", "base"] = "train",
-        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-        download: bool = False,
-        verbose: bool = False,
-    ) -> None:
-        super().__init__(
-            root,
-            image_set,
-            transforms,
-            download,
-            verbose,
-        )
-        self._bboxes_per_size = True
-
-    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
-        filepaths: list[str] = []
-        targets: list[str] = []
-        datum_metadata: dict[str, list[Any]] = {}
-        metadata_list: list[dict[str, Any]] = []
-        image_sets: dict[str, list[int]] = {
-            "base": list(range(len(self._resources))),
-            "train": list(range(3)),
-            "operational": list(range(3, len(self._resources))),
-        }
-
-        # Load the data
-        resource_indices = image_sets[self.image_set]
-        for idx in resource_indices:
-            self._resource = self._resources[idx]
-            filepath, target, metadata = super()._load_data()
-            filepaths.extend(filepath)
-            targets.extend(target)
-            metadata_list.append(metadata)
-
-        # Adjust datum metadata to correct format
-        for data_dict in metadata_list:
-            for key, val in data_dict.items():
-                if key not in datum_metadata:
-                    datum_metadata[str(key)] = []
-                datum_metadata[str(key)].extend(val)
-
-        return filepaths, targets, datum_metadata
-
-    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
-        file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
-        data_folder = sorted((self.path / self._resource.filename[:-4]).glob("*.jpg"))
-        if not data_folder:
-            raise FileNotFoundError
-
-        for entry in data_folder:
-            # Remove file extension and split by "_"
-            parts = entry.stem.split("_")
-            file_data["image_id"].append(parts[0])
-            file_data["year"].append(parts[1])
-            file_data["data_path"].append(str(entry))
-            file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
-        data = file_data.pop("data_path")
-        annotations = file_data.pop("label_path")
-
-        return data, annotations, file_data
-
-    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
-        """Function for extracting the info out of the text files"""
-        labels: list[int] = []
-        boxes: list[list[float]] = []
-        with open(annotation) as f:
-            for line in f.readlines():
-                out = line.strip().split()
-                labels.append(int(out[0]))
-
-                xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
-
-                x0 = xcenter - width / 2
-                x1 = x0 + width
-                y0 = ycenter - height / 2
-                y1 = y0 + height
-                boxes.append([x0, y0, x1, y1])
-
-        return boxes, labels, {}
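A short worked example of the center-to-corner box conversion performed by _read_annotations above; the label line values are made up.

    # A label line "0 0.50 0.25 0.10 0.20" (class, then x-center, y-center, width, height
    # as fractions of the image size) yields class 0 and the corner-format box below.
    xcenter, ycenter, width, height = 0.50, 0.25, 0.10, 0.20
    x0 = xcenter - width / 2    # 0.45
    x1 = x0 + width             # 0.55
    y0 = ycenter - height / 2   # 0.15
    y1 = y0 + height            # 0.35
    box = [x0, y0, x1, y1]      # still normalized; _bboxes_per_size = True in __init__
                                # suggests the base class rescales by image size downstream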
dataeval/utils/datasets/_mixin.py
@@ -1,54 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from typing import Any, Generic, TypeVar
-
-import numpy as np
-import torch
-from numpy.typing import NDArray
-from PIL import Image
-
-_TArray = TypeVar("_TArray")
-
-
-class BaseDatasetMixin(Generic[_TArray]):
-    index2label: dict[int, str]
-
-    def _as_array(self, raw: list[Any]) -> _TArray: ...
-    def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
-    def _read_file(self, path: str) -> _TArray: ...
-
-
-class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
-    def _as_array(self, raw: list[Any]) -> NDArray[Any]:
-        return np.asarray(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> NDArray[Any]:
-        if isinstance(value, int):
-            encoded = np.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = np.zeros((len(value), len(self.index2label)))
-            encoded[np.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> NDArray[Any]:
-        return np.array(Image.open(path)).transpose(2, 0, 1)
-
-
-class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
-    def _as_array(self, raw: list[Any]) -> torch.Tensor:
-        return torch.as_tensor(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
-        if isinstance(value, int):
-            encoded = torch.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = torch.zeros((len(value), len(self.index2label)))
-            encoded[torch.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> torch.Tensor:
-        return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
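A standalone sketch of the one-hot encoding done by BaseDatasetNumpyMixin._one_hot_encode above, assuming a small hypothetical index2label mapping.

    import numpy as np

    index2label = {0: "cat", 1: "dog", 2: "frog"}   # hypothetical 3-class mapping
    value = [2, 0, 1]                               # integer labels for three datums
    encoded = np.zeros((len(value), len(index2label)))
    encoded[np.arange(len(value)), value] = 1
    # encoded -> [[0., 0., 1.],
    #             [1., 0., 0.],
    #             [0., 1., 0.]]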