dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +188 -178
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +4 -5
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metadata/_distance.py +10 -7
  22. dataeval/metadata/_ood.py +11 -103
  23. dataeval/metrics/bias/_balance.py +23 -33
  24. dataeval/metrics/bias/_diversity.py +16 -14
  25. dataeval/metrics/bias/_parity.py +18 -18
  26. dataeval/metrics/estimators/_divergence.py +2 -4
  27. dataeval/metrics/stats/_base.py +103 -42
  28. dataeval/metrics/stats/_boxratiostats.py +21 -19
  29. dataeval/metrics/stats/_dimensionstats.py +14 -10
  30. dataeval/metrics/stats/_hashstats.py +1 -1
  31. dataeval/metrics/stats/_pixelstats.py +6 -6
  32. dataeval/metrics/stats/_visualstats.py +3 -3
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +24 -70
  35. dataeval/outputs/_drift.py +1 -9
  36. dataeval/outputs/_linters.py +11 -11
  37. dataeval/outputs/_stats.py +82 -23
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +54 -28
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +22 -12
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
  62. dataeval-0.86.2.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.86.0.dist-info/RECORD +0 -114
  65. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/utils/datasets/_antiuav.py (new file)
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+__all__ = []
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal, Sequence
+
+from defusedxml.ElementTree import parse
+from numpy.typing import NDArray
+
+from dataeval.utils.datasets._base import BaseODDataset, DataLocation
+from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
+
+if TYPE_CHECKING:
+    from dataeval.typing import Transform
+
+
+class AntiUAVDetection(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+    """
+    A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
+
+    The dataset comes from the paper
+    `Vision-based Anti-UAV Detection and Tracking <https://ieeexplore.ieee.org/document/9785379>`_
+    by Jie Zhao et. al. (2022).
+
+    The dataset is approximately 1.3 GB and can be found `here <https://github.com/wangdongdut/DUT-Anti-UAV>`_.
+    Images are collected against a variety of different backgrounds with a variety in the number and type of UAV.
+    Ground truth labels are provided for the train, validation and test set.
+    There are 35 different types of drones along with a variety in lighting conditions and weather conditions.
+
+    There are 10,000 images: 5200 images in the training set, 2200 images in the validation set,
+    and 2600 images in the test set.
+    The dataset only has a single UAV class with the focus being on identifying object location in the image.
+    Ground-truth bounding boxes are provided in (x0, y0, x1, y1) format.
+    The images come in a variety of sizes from 3744 x 5616 to 160 x 240.
+
+    Parameters
+    ----------
+    root : str or pathlib.Path
+        Root directory where the data should be downloaded to or
+        the ``antiuavdetection`` folder of the already downloaded data.
+    image_set: "train", "val", "test", or "base", default "train"
+        If "base", then the full dataset is selected (train, val and test).
+    transforms : Transform, Sequence[Transform] or None, default None
+        Transform(s) to apply to the data.
+    download : bool, default False
+        If True, downloads the dataset from the internet and puts it in root directory.
+        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+    verbose : bool, default False
+        If True, outputs print statements.
+
+    Attributes
+    ----------
+    path : pathlib.Path
+        Location of the folder containing the data.
+    image_set : "train", "val", "test", or "base"
+        The selected image set from the dataset.
+    index2label : dict[int, str]
+        Dictionary which translates from class integers to the associated class strings.
+    label2index : dict[str, int]
+        Dictionary which translates from class strings to the associated class integers.
+    metadata : DatasetMetadata
+        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
+    transforms : Sequence[Transform]
+        The transforms to be applied to the data.
+    size : int
+        The size of the dataset.
+
+    Note
+    ----
+    Data License: `Apache 2.0 <https://www.apache.org/licenses/LICENSE-2.0.txt>`_
+    """
+
+    # Need to run the sha256 on the files and then store that
+    _resources = [
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1RVsSGPUKTdmoyoPTBTWwroyulLek1eTj&export=download&authuser=0&confirm=t&uuid=6bca4f94-a242-4bc2-9663-fb03cd94ef2c&at=APcmpox0--NroQ_3bqeTFaJxP7Pw%3A1746552902927",
+            filename="train.zip",
+            md5=False,
+            checksum="14f927290556df60e23cedfa80dffc10dc21e4a3b6843e150cfc49644376eece",
+        ),
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1333uEQfGuqTKslRkkeLSCxylh6AQ0X6n&export=download&authuser=0&confirm=t&uuid=c2ad2f01-aca8-4a85-96bb-b8ef6e40feea&at=APcmpozY-8bhk3nZSFaYbE8rq1Fi%3A1746551543297",
+            filename="val.zip",
+            md5=False,
+            checksum="238be0ceb3e7c5be6711ee3247e49df2750d52f91f54f5366c68bebac112ebf8",
+        ),
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1L1zeW1EMDLlXHClSDcCjl3rs_A6sVai0&export=download&authuser=0&confirm=t&uuid=5a1d7650-d8cd-4461-8354-7daf7292f06c&at=APcmpozLQC1CuP-n5_UX2JnP53Zo%3A1746551676177",
+            filename="test.zip",
+            md5=False,
+            checksum="a671989a01cff98c684aeb084e59b86f4152c50499d86152eb970a9fc7fb1cbe",
+        ),
+    ]
+
+    index2label: dict[int, str] = {
+        0: "unknown",
+        1: "UAV",
+    }
+
+    def __init__(
+        self,
+        root: str | Path,
+        image_set: Literal["train", "val", "test", "base"] = "train",
+        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+        download: bool = False,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(
+            root,
+            image_set,
+            transforms,
+            download,
+            verbose,
+        )
+
+    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
+        filepaths: list[str] = []
+        targets: list[str] = []
+        datum_metadata: dict[str, list[Any]] = {}
+
+        # If base, load all resources
+        if self.image_set == "base":
+            metadata_list: list[dict[str, Any]] = []
+
+            for resource in self._resources:
+                self._resource = resource
+                resource_filepaths, resource_targets, resource_metadata = super()._load_data()
+                filepaths.extend(resource_filepaths)
+                targets.extend(resource_targets)
+                metadata_list.append(resource_metadata)
+
+            # Combine metadata
+            for data_dict in metadata_list:
+                for key, val in data_dict.items():
+                    str_key = str(key) # Ensure key is string
+                    if str_key not in datum_metadata:
+                        datum_metadata[str_key] = []
+                    datum_metadata[str_key].extend(val)
+
+        else:
+            # Grab only the desired data
+            for resource in self._resources:
+                if self.image_set in resource.filename:
+                    self._resource = resource
+                    resource_filepaths, resource_targets, resource_metadata = super()._load_data()
+                    filepaths.extend(resource_filepaths)
+                    targets.extend(resource_targets)
+                    datum_metadata.update(resource_metadata)
+
+        return filepaths, targets, datum_metadata
+
+    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
+        resource_name = self._resource.filename[:-4]
+        base_dir = self.path / resource_name
+        data_folder = sorted((base_dir / "img").glob("*.jpg"))
+        if not data_folder:
+            raise FileNotFoundError
+
+        file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
+        data = [str(entry) for entry in data_folder]
+        annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
+
+        return data, annotations, file_data
+
+    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
+        """Function for extracting the info for the label and boxes"""
+        boxes: list[list[float]] = []
+        labels = []
+        root = parse(annotation).getroot()
+        if root is None:
+            raise ValueError(f"Unable to parse {annotation}")
+        additional_meta: dict[str, Any] = {
+            "image_width": int(root.findtext("size/width", default="-1")),
+            "image_height": int(root.findtext("size/height", default="-1")),
+            "image_depth": int(root.findtext("size/depth", default="-1")),
+        }
+        for obj in root.findall("object"):
+            labels.append(1 if obj.findtext("name", default="") == "UAV" else 0)
+            boxes.append(
+                [
+                    float(obj.findtext("bndbox/xmin", default="0")),
+                    float(obj.findtext("bndbox/ymin", default="0")),
+                    float(obj.findtext("bndbox/xmax", default="0")),
+                    float(obj.findtext("bndbox/ymax", default="0")),
+                ]
+            )
+
+        return boxes, labels, additional_meta
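
For orientation, here is a minimal usage sketch of the new dataset class. It assumes AntiUAVDetection is re-exported from dataeval.utils.datasets (the two-line addition to that package's __init__.py in the file list suggests it is), that the Google Drive resources are reachable, and that ObjectDetectionTarget exposes a boxes attribute; the root path is hypothetical.

    from dataeval.utils.datasets import AntiUAVDetection

    # Download (or reuse) the validation split and inspect one sample.
    dataset = AntiUAVDetection(root="./data", image_set="val", download=True, verbose=True)
    image, target, metadata = dataset[0]

    print(len(dataset))   # number of images in the "val" split
    print(image.shape)    # CHW NumPy array; sizes vary per image
    print(target.boxes)   # (x0, y0, x1, y1) ground-truth boxes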
dataeval/utils/datasets/_base.py
@@ -6,6 +6,8 @@ from abc import abstractmethod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar

+import numpy as np
+
 from dataeval.utils.datasets._fileio import _ensure_exists
 from dataeval.utils.datasets._mixin import BaseDatasetMixin
 from dataeval.utils.datasets._types import (
@@ -101,11 +103,7 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge

     def _get_dataset_dir(self) -> Path:
         # Create a designated folder for this dataset (named after the class)
-        if self._root.stem in [
-            self.__class__.__name__.lower(),
-            self.__class__.__name__.upper(),
-            self.__class__.__name__,
-        ]:
+        if self._root.stem.lower() == self.__class__.__name__.lower():
             dataset_dir: Path = self._root
         else:
             dataset_dir: Path = self._root / self.__class__.__name__.lower()
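
The simplified check above treats the dataset folder name case-insensitively rather than matching only the exact, all-lowercase, or all-uppercase class name. A quick illustration of the new condition (path hypothetical):

    from pathlib import Path

    root = Path("./data/Cifar10")
    # Mixed-case folder names now count as the dataset directory too.
    print(root.stem.lower() == "CIFAR10".lower())  # True, so root itself is used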
@@ -114,8 +112,7 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge
         return dataset_dir

     def _unique_id(self) -> str:
-        unique_id = f"{self.__class__.__name__}_{self.image_set}"
-        return unique_id
+        return f"{self.__class__.__name__}_{self.image_set}"

     def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
         """
@@ -188,6 +185,8 @@ class BaseODDataset(
     Base class for object detection datasets.
     """

+    _bboxes_per_size: bool = False
+
     def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
         """
         Args
@@ -204,8 +203,12 @@ class BaseODDataset(
         boxes, labels, additional_metadata = self._read_annotations(self._targets[index])
         # Get the image
         img = self._read_file(self._filepaths[index])
+        img_size = img.shape
         img = self._transform(img)
-
+        # Adjust labels if necessary
+        if self._bboxes_per_size and boxes:
+            boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
+        # Create the Object Detection Target
         target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))

         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
dataeval/utils/datasets/_cifar10.py
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar

 import numpy as np
 from numpy.typing import NDArray
-from PIL import Image

 from dataeval.utils.datasets._base import BaseICDataset, DataLocation
 from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
@@ -26,7 +25,7 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory of dataset where the ``mnist`` folder exists.
+        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
     image_set : "train", "test" or "base", default "train"
         If "base", returns all of the data to allow the user to create their own splits.
     transforms : Transform, Sequence[Transform] or None, default None
@@ -93,50 +92,110 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
             verbose,
         )

+    def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
+        batch_nums = np.zeros(60000, dtype=np.uint8)
+        all_labels = np.zeros(60000, dtype=np.uint8)
+        all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
+        # Process each batch file, skipping .meta and .html files
+        for batch_file in data_folder:
+            # Get batch parameters
+            batch_type = "test" if "test" in batch_file.stem else "train"
+            batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
+
+            # Load data
+            batch_images, batch_labels = self._unpack_batch_files(batch_file)
+
+            # Stack data
+            num_images = batch_images.shape[0]
+            batch_start = batch_num * num_images
+            all_images[batch_start : batch_start + num_images] = batch_images
+            all_labels[batch_start : batch_start + num_images] = batch_labels
+            batch_nums[batch_start : batch_start + num_images] = batch_num
+
+        # Save data
+        self._loaded_data = all_images
+        np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
+
+        # Select data
+        image_list = np.arange(all_labels.shape[0]).astype(str)
+        if self.image_set == "train":
+            return (
+                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
+                all_labels[batch_nums != 5].tolist(),
+                {"batch_num": batch_nums[batch_nums != 5].tolist()},
+            )
+        if self.image_set == "test":
+            return (
+                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
+                all_labels[batch_nums == 5].tolist(),
+                {"batch_num": batch_nums[batch_nums == 5].tolist()},
+            )
+        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
+
     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
         """Function to load in the file paths for the data and labels and retrieve metadata"""
-        file_meta = {"batch_num": []}
-        raw_data = []
-        labels = []
-        data_folder = self.path / "cifar-10-batches-bin"
-        save_folder = self.path / "images"
-        image_sets: dict[str, list[str]] = {"base": [], "train": [], "test": []}
-
-        # Process each batch file, skipping .meta and .html files
-        for entry in data_folder.iterdir():
-            if entry.suffix == ".bin":
-                batch_data, batch_labels = self._unpack_batch_files(entry)
-                raw_data.append(batch_data)
-                group = "train" if "test" not in entry.stem else "test"
-                name_split = entry.stem.split("_")
-                batch_num = int(name_split[-1]) - 1 if group == "train" else 5
-                file_names = [
-                    str(save_folder / f"{i + 10000 * batch_num:05d}_{self.index2label[label]}.png")
-                    for i, label in enumerate(batch_labels)
-                ]
-                image_sets["base"].extend(file_names)
-                image_sets[group].extend(file_names)
-
-                if self.image_set in (group, "base"):
-                    labels.extend(batch_labels)
-                    file_meta["batch_num"].extend([batch_num] * len(labels))
-
-        # Stack and reshape images
-        images = np.vstack(raw_data).reshape(-1, 3, 32, 32)
-
-        # Save the raw data into images if not already there
-        if not save_folder.exists():
-            save_folder.mkdir(exist_ok=True)
-            for i, file in enumerate(image_sets["base"]):
-                Image.fromarray(images[i].transpose(1, 2, 0).astype(np.uint8)).save(file)
-
-        return image_sets[self.image_set], labels, file_meta
-
-    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[Any], list[int]]:
+        data_file = self.path / "cifar10.npz"
+        if not data_file.exists():
+            data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
+            if not data_folder:
+                raise FileNotFoundError
+            return self._load_bin_data(data_folder)
+
+        # Load data
+        data = np.load(data_file)
+        self._loaded_data = data["images"]
+        all_labels = data["labels"]
+        batch_nums = data["batches"]
+
+        # Select data
+        image_list = np.arange(all_labels.shape[0]).astype(str)
+        if self.image_set == "train":
+            return (
+                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
+                all_labels[batch_nums != 5].tolist(),
+                {"batch_num": batch_nums[batch_nums != 5].tolist()},
+            )
+        if self.image_set == "test":
+            return (
+                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
+                all_labels[batch_nums == 5].tolist(),
+                {"batch_num": batch_nums[batch_nums == 5].tolist()},
+            )
+        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
+
+    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
         # Load pickle data with latin1 encoding
         with file_path.open("rb") as f:
-            buffer = np.frombuffer(f.read(), "B")
-            labels = buffer[::3073]
-            pixels = np.delete(buffer, np.arange(0, buffer.size, 3073))
-            images = pixels.reshape(-1, 3072)
-            return images, labels.tolist()
+            buffer = np.frombuffer(f.read(), dtype=np.uint8)
+            # Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
+            entry_size = 1 + 3072
+            num_entries = buffer.size // entry_size
+            # Extract labels (first byte of each entry)
+            labels = buffer[::entry_size]
+
+            # Extract image data and reshape to (N, 3, 32, 32)
+            images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
+            for i in range(num_entries):
+                # Skip the label byte and get image data for this entry
+                start_idx = i * entry_size + 1 # +1 to skip label
+                img_flat = buffer[start_idx : start_idx + 3072]
+
+                # The CIFAR format stores channels in blocks (all R, then all G, then all B)
+                # Each channel block is 1024 bytes (32x32)
+                red_channel = img_flat[0:1024].reshape(32, 32)
+                green_channel = img_flat[1024:2048].reshape(32, 32)
+                blue_channel = img_flat[2048:3072].reshape(32, 32)
+
+                # Stack the channels in the proper C×H×W format
+                images[i, 0] = red_channel # Red channel
+                images[i, 1] = green_channel # Green channel
+                images[i, 2] = blue_channel # Blue channel
+            return images, labels
+
+    def _read_file(self, path: str) -> NDArray[Any]:
+        """
+        Function to grab the correct image from the loaded data.
+        Overwrite of the base `_read_file` because data is an all or nothing load.
+        """
+        index = int(path)
+        return self._loaded_data[index]
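
The loader above parses the standard CIFAR-10 binary batch layout: each record is 1 label byte followed by 3072 image bytes stored channel-planar (1024 bytes of R, then G, then B). A standalone, vectorized sketch of the same decoding (file path hypothetical):

    import numpy as np

    def unpack_cifar10_bin(path: str) -> tuple[np.ndarray, np.ndarray]:
        # Each 3073-byte record: 1 label byte + 3072 channel-planar image bytes.
        records = np.fromfile(path, dtype=np.uint8).reshape(-1, 3073)
        labels = records[:, 0]
        images = records[:, 1:].reshape(-1, 3, 32, 32)  # planar R/G/B maps directly to CHW
        return images, labels

    images, labels = unpack_cifar10_bin("cifar-10-batches-bin/data_batch_1.bin")
    print(images.shape, labels[:10])  # (10000, 3, 32, 32) and the first ten labels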
dataeval/utils/datasets/_fileio.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 __all__ = []

 import hashlib
-import shutil
 import tarfile
 import zipfile
 from pathlib import Path
@@ -15,7 +14,12 @@ ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
 COMPRESS_ENDINGS = [".gz", ".bz2"]


-def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool:
+def _print(text: str, verbose: bool) -> None:
+    if verbose:
+        print(text)
+
+
+def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
     hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
     with open(fpath, "rb") as fpath_file:
         while chunk := fpath_file.read(chunk_size):
@@ -23,7 +27,7 @@ def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool
     return hasher.hexdigest() == file_md5


-def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:
+def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
    """Download a single resource from its URL to the `data_folder`."""
     error_msg = "URL fetch failure on {}: {} -- {}"
     try:
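
_validate_file streams the file through a hash in fixed-size chunks and compares the hex digest against the checksum recorded in the corresponding DataLocation. An equivalent standalone check, using the train.zip SHA-256 from _antiuav.py (the local path is hypothetical):

    import hashlib
    from pathlib import Path

    def sha256_of(fpath: Path, chunk_size: int = 65535) -> str:
        # Stream in chunks so multi-gigabyte archives never need to fit in memory.
        hasher = hashlib.sha256()
        with fpath.open("rb") as f:
            while chunk := f.read(chunk_size):
                hasher.update(chunk)
        return hasher.hexdigest()

    expected = "14f927290556df60e23cedfa80dffc10dc21e4a3b6843e150cfc49644376eece"
    print(sha256_of(Path("./data/antiuavdetection/train.zip")) == expected)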
@@ -36,7 +40,7 @@ def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:

     total_size = int(response.headers.get("content-length", 0))
     block_size = 8192 # 8 KB
-    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
+    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)

     with open(file_path, "wb") as f:
         for chunk in response.iter_content(block_size):
@@ -49,7 +53,7 @@ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
     """Extracts the zip file to the given directory."""
     try:
         with zipfile.ZipFile(file_path, "r") as zip_ref:
-            zip_ref.extractall(extract_to)
+            zip_ref.extractall(extract_to) # noqa: S202
         file_path.unlink()
     except zipfile.BadZipFile:
         raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
@@ -59,36 +63,15 @@ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
     """Extracts a tar file (or compressed tar) to the specified directory."""
     try:
         with tarfile.open(file_path, "r:*") as tar_ref:
-            tar_ref.extractall(extract_to)
+            tar_ref.extractall(extract_to) # noqa: S202
         file_path.unlink()
     except tarfile.TarError:
         raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")


-def _flatten_extraction(base_directory: Path, verbose: bool = False) -> None:
-    """
-    If the extracted folder contains only directories (and no files),
-    move all its subfolders to the dataset_dir and remove the now-empty folder.
-    """
-    for child in base_directory.iterdir():
-        if child.is_dir():
-            inner_list = list(child.iterdir())
-            if all(subchild.is_dir() for subchild in inner_list):
-                for subchild in child.iterdir():
-                    if verbose:
-                        print(f"Moving {subchild.stem} to {base_directory}")
-                    shutil.move(subchild, base_directory)
-
-                if verbose:
-                    print(f"Removing empty folder {child.stem}")
-                child.rmdir()
-
-                # Checking for additional placeholder folders
-                if len(inner_list) == 1:
-                    _flatten_extraction(base_directory, verbose)
-
-
-def _archive_extraction(file_ext, file_path, directory, compression: bool = False, verbose: bool = False):
+def _extract_archive(
+    file_ext: str, file_path: Path, directory: Path, compression: bool = False, verbose: bool = False
+) -> None:
     """
     Single function to extract and then flatten if necessary.
     Recursively extracts nested zip files as well.
@@ -102,14 +85,9 @@ def _archive_extraction(file_ext, file_path, directory, compression: bool = Fals
     # Does NOT extract in place - extracts everything to directory
     for child in directory.iterdir():
         if child.suffix == ".zip":
-            if verbose:
-                print(f"Extracting nested zip: {child} to {directory}")
+            _print(f"Extracting nested zip: {child} to {directory}", verbose)
             _extract_zip_archive(child, directory)

-    # Determine if there are nested folders and remove them
-    # Helps ensure there that data is at most one folder below main directory
-    _flatten_extraction(directory, verbose)
-

 def _ensure_exists(
     url: str,
@@ -137,18 +115,16 @@ def _ensure_exists(

     # Download file if it doesn't exist.
     if not check_path.exists() and download:
-        if verbose:
-            print(f"Downloading {filename} from {url}")
-        _download_dataset(url, check_path)
+        _print(f"Downloading {filename} from {url}", verbose)
+        _download_dataset(url, check_path, verbose=verbose)

         if not _validate_file(check_path, checksum, md5):
             raise Exception("File checksum mismatch. Remove current file and retry download.")

         # If the file is a zip, tar or tgz extract it into the designated folder.
         if file_ext in ARCHIVE_ENDINGS:
-            if verbose:
-                print(f"Extracting {filename}...")
-            _archive_extraction(file_ext, check_path, directory, compression, verbose)
+            _print(f"Extracting {filename}...", verbose)
+            _extract_archive(file_ext, check_path, directory, compression, verbose)

     elif not check_path.exists() and not download:
         raise FileNotFoundError(
@@ -159,10 +135,8 @@ def _ensure_exists(
     else:
         if not _validate_file(check_path, checksum, md5):
             raise Exception("File checksum mismatch. Remove current file and retry download.")
-        if verbose:
-            print(f"{filename} already exists, skipping download.")
+        _print(f"{filename} already exists, skipping download.", verbose)

         if file_ext in ARCHIVE_ENDINGS:
-            if verbose:
-                print(f"Extracting {filename}...")
-            _archive_extraction(file_ext, check_path, directory, compression, verbose)
+            _print(f"Extracting {filename}...", verbose)
+            _extract_archive(file_ext, check_path, directory, compression, verbose)
dataeval/utils/datasets/_milco.py
@@ -38,7 +38,7 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory of dataset where the ``milco`` folder exists.
+        Root directory where the data should be downloaded to or the ``milco`` folder of the already downloaded data.
     image_set: "train", "operational", or "base", default "train"
         If "train", then the images from 2015, 2017 and 2021 are selected,
         resulting in 315 MILCO objects and 177 NOMBO objects.
@@ -128,6 +128,7 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
             download,
             verbose,
         )
+        self._bboxes_per_size = True

     def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
         filepaths: list[str] = []
@@ -160,15 +161,17 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):

     def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
         file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
-        data_folder = self.path / self._resource.filename[:-4]
-        for entry in data_folder.iterdir():
-            if entry.is_file() and entry.suffix == ".jpg":
-                # Remove file extension and split by "_"
-                parts = entry.stem.split("_")
-                file_data["image_id"].append(parts[0])
-                file_data["year"].append(parts[1])
-                file_data["data_path"].append(str(entry))
-                file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
+        data_folder = sorted((self.path / self._resource.filename[:-4]).glob("*.jpg"))
+        if not data_folder:
+            raise FileNotFoundError
+
+        for entry in data_folder:
+            # Remove file extension and split by "_"
+            parts = entry.stem.split("_")
+            file_data["image_id"].append(parts[0])
+            file_data["year"].append(parts[1])
+            file_data["data_path"].append(str(entry))
+            file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
         data = file_data.pop("data_path")
         annotations = file_data.pop("label_path")

@@ -180,8 +183,15 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
         boxes: list[list[float]] = []
         with open(annotation) as f:
             for line in f.readlines():
-                out = line.strip().split(" ")
+                out = line.strip().split()
                 labels.append(int(out[0]))
-                boxes.append([float(out[1]), float(out[2]), float(out[3]), float(out[4])])
+
+                xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
+
+                x0 = xcenter - width / 2
+                x1 = x0 + width
+                y0 = ycenter - height / 2
+                y1 = y0 + height
+                boxes.append([x0, y0, x1, y1])

         return boxes, labels, {}
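
The reworked parsing converts YOLO-style (x_center, y_center, width, height) values into (x0, y0, x1, y1) corners, which the base class can then scale to pixels via the _bboxes_per_size path added in _base.py. A quick worked example with a made-up label line:

    line = "0 0.50 0.40 0.20 0.30"  # class, x_center, y_center, width, height (normalized)
    cls, xc, yc, w, h = line.split()
    xc, yc, w, h = float(xc), float(yc), float(w), float(h)

    x0, y0 = xc - w / 2, yc - h / 2  # (0.40, 0.25)
    x1, y1 = x0 + w, y0 + h          # (0.60, 0.55)
    print(int(cls), [x0, y0, x1, y1])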
dataeval/utils/datasets/_mixin.py
@@ -34,8 +34,7 @@ class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
         return encoded

     def _read_file(self, path: str) -> NDArray[Any]:
-        x = np.array(Image.open(path)).transpose(2, 0, 1)
-        return x
+        return np.array(Image.open(path)).transpose(2, 0, 1)


 class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
@@ -52,5 +51,4 @@ class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
         return encoded

     def _read_file(self, path: str) -> torch.Tensor:
-        x = torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
-        return x
+        return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
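
Both mixins return images in channels-first (CHW) order; the one-line returns above are behaviorally identical to the two-step versions they replace. For reference, the transpose rearranges PIL's HWC array like so (file name hypothetical):

    import numpy as np
    from PIL import Image

    arr = np.array(Image.open("example.jpg"))  # PIL yields HWC, e.g. (480, 640, 3)
    chw = arr.transpose(2, 0, 1)               # reorder axes to CHW, (3, 480, 640)
    print(arr.shape, chw.shape)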
dataeval/utils/datasets/_mnist.py
@@ -48,7 +48,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory of dataset where the ``mnist`` folder exists.
+        Root directory where the data should be downloaded to or the ``minst`` folder of the already downloaded data.
     image_set : "train", "test" or "base", default "train"
         If "base", returns all of the data to allow the user to create their own splits.
     corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
@@ -154,7 +154,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
         corruption = self.corruption if self.corruption is not None else "identity"
-        base_path = self.path / corruption
+        base_path = self.path / "mnist_c" / corruption
         if self.image_set == "base":
             raw_data = []
             raw_labels = []
@@ -191,8 +191,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):

     def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
         """Function to load in the data numpy array for the previously chosen corrupt format"""
-        x = np.load(path, allow_pickle=False)
-        return x
+        return np.load(path, allow_pickle=False)

     def _read_file(self, path: str) -> NDArray[Any]:
         """