maite-datasets 0.0.1 (maite_datasets-0.0.1.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets-0.0.1/.gitignore +13 -0
- maite_datasets-0.0.1/LICENSE +21 -0
- maite_datasets-0.0.1/PKG-INFO +91 -0
- maite_datasets-0.0.1/README.md +65 -0
- maite_datasets-0.0.1/pyproject.toml +115 -0
- maite_datasets-0.0.1/src/maite_datasets/__init__.py +1 -0
- maite_datasets-0.0.1/src/maite_datasets/_base.py +254 -0
- maite_datasets-0.0.1/src/maite_datasets/_fileio.py +174 -0
- maite_datasets-0.0.1/src/maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets-0.0.1/src/maite_datasets/_mixin/_numpy.py +28 -0
- maite_datasets-0.0.1/src/maite_datasets/_mixin/_torch.py +28 -0
- maite_datasets-0.0.1/src/maite_datasets/_protocols.py +224 -0
- maite_datasets-0.0.1/src/maite_datasets/_types.py +54 -0
- maite_datasets-0.0.1/src/maite_datasets/image_classification/__init__.py +11 -0
- maite_datasets-0.0.1/src/maite_datasets/image_classification/_cifar10.py +233 -0
- maite_datasets-0.0.1/src/maite_datasets/image_classification/_mnist.py +215 -0
- maite_datasets-0.0.1/src/maite_datasets/image_classification/_ships.py +150 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/__init__.py +20 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/_antiuav.py +200 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/_milco.py +207 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/_seadrone.py +551 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/_voc.py +510 -0
- maite_datasets-0.0.1/src/maite_datasets/object_detection/_voc_torch.py +65 -0
- maite_datasets-0.0.1/src/maite_datasets/py.typed +0 -0
maite_datasets-0.0.1/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 ARiA
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
maite_datasets-0.0.1/PKG-INFO
@@ -0,0 +1,91 @@
+Metadata-Version: 2.4
+Name: maite-datasets
+Version: 0.0.1
+Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
+Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
+License-Expression: MIT
+License-File: LICENSE
+Classifier: Development Status :: 4 - Beta
+Classifier: Framework :: Pytest
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Requires-Python: >=3.9
+Requires-Dist: defusedxml>=0.7.1
+Requires-Dist: numpy>=1.24.2
+Requires-Dist: pillow>=10.3.0
+Requires-Dist: requests>=2.32.3
+Requires-Dist: typing-extensions>=4.12
+Provides-Extra: tqdm
+Requires-Dist: tqdm>=4.66; extra == 'tqdm'
+Description-Content-Type: text/markdown
+
+# MAITE Datasets
+
+MAITE Datasets are a collection of public datasets wrapped in a [MAITE](https://mit-ll-ai-technology.github.io/maite/) compliant format.
+
+## Installation
+
+To install and use `maite-datasets` you can use pip:
+
+```bash
+pip install maite-datasets
+```
+
+For status bar indicators when downloading, you can include the extra `tqdm` when installing:
+
+```bash
+pip install maite-datasets[tqdm]
+```
+
+## Available Datasets
+
+| Task           | Dataset          | Description                                                          |
+|----------------|------------------|----------------------------------------------------------------------|
+| Classification | CIFAR10          | [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.      |
+| Classification | MNIST            | A dataset of hand-written digits.                                     |
+| Classification | Ships            | A dataset that focuses on identifying ships from satellite images.    |
+| Detection      | AntiUAVDetection | A UAV detection dataset in natural images with varying backgrounds.   |
+| Detection      | MILCO            | A side-scan sonar dataset focused on mine-like object detection.      |
+| Detection      | Seadrone         | A UAV dataset focused on open water object detection.                 |
+| Detection      | VOCDetection     | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset.        |
+
+## Usage
+
+Here is an example of how to import MNIST for usage with your workflow.
+
+```python
+>>> from maite_datasets.image_classification import MNIST
+
+>>> mnist = MNIST(root="data", download=True)
+>>> print(mnist)
+MNIST Dataset
+-------------
+Corruption: None
+Transforms: []
+Image_set: train
+Metadata: {'id': 'MNIST_train', 'index2label': {0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five', 6: 'six', 7: 'seven', 8: 'eight', 9: 'nine'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/mnist
+Size: 60000
+
+>>> print("tuple("+", ".join([str(type(t)) for t in mnist[0]])+")")
+tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
+```
+
+## Additional Information
+
+For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
+
+## Acknowledgement
+
+### CDAO Funding Acknowledgement
+
+This material is based upon work supported by the Chief Digital and Artificial
+Intelligence Office under Contract No. W519TC-23-9-2033. The views and
+conclusions contained herein are those of the author(s) and should not be
+interpreted as necessarily representing the official policies or endorsements,
+either expressed or implied, of the U.S. Government.
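
As the package description above shows, each datum produced by these datasets is an (image, target, metadata) tuple, and the datasets themselves are iterable. A minimal sketch of consuming that tuple directly, building on the MNIST example from the usage section (the loop body is illustrative and not part of the package):

```python
import numpy as np
from maite_datasets.image_classification import MNIST

mnist = MNIST(root="data", download=True)

# Each datum follows the MAITE protocol as an (image, target, metadata) tuple,
# where target is a one-hot score vector over the dataset's classes.
for image, target, metadata in mnist:
    label_index = int(np.argmax(target))                      # recover the class index
    label_name = mnist.metadata["index2label"][label_index]   # e.g. "five"
    print(metadata["id"], image.shape, label_name)
    break  # inspect only the first datum
```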
maite_datasets-0.0.1/README.md
@@ -0,0 +1,65 @@
+# MAITE Datasets
+
+MAITE Datasets are a collection of public datasets wrapped in a [MAITE](https://mit-ll-ai-technology.github.io/maite/) compliant format.
+
+## Installation
+
+To install and use `maite-datasets` you can use pip:
+
+```bash
+pip install maite-datasets
+```
+
+For status bar indicators when downloading, you can include the extra `tqdm` when installing:
+
+```bash
+pip install maite-datasets[tqdm]
+```
+
+## Available Datasets
+
+| Task           | Dataset          | Description                                                          |
+|----------------|------------------|----------------------------------------------------------------------|
+| Classification | CIFAR10          | [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.      |
+| Classification | MNIST            | A dataset of hand-written digits.                                     |
+| Classification | Ships            | A dataset that focuses on identifying ships from satellite images.    |
+| Detection      | AntiUAVDetection | A UAV detection dataset in natural images with varying backgrounds.   |
+| Detection      | MILCO            | A side-scan sonar dataset focused on mine-like object detection.      |
+| Detection      | Seadrone         | A UAV dataset focused on open water object detection.                 |
+| Detection      | VOCDetection     | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset.        |
+
+## Usage
+
+Here is an example of how to import MNIST for usage with your workflow.
+
+```python
+>>> from maite_datasets.image_classification import MNIST
+
+>>> mnist = MNIST(root="data", download=True)
+>>> print(mnist)
+MNIST Dataset
+-------------
+Corruption: None
+Transforms: []
+Image_set: train
+Metadata: {'id': 'MNIST_train', 'index2label': {0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five', 6: 'six', 7: 'seven', 8: 'eight', 9: 'nine'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/mnist
+Size: 60000
+
+>>> print("tuple("+", ".join([str(type(t)) for t in mnist[0]])+")")
+tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
+```
+
+## Additional Information
+
+For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
+
+## Acknowledgement
+
+### CDAO Funding Acknowledgement
+
+This material is based upon work supported by the Chief Digital and Artificial
+Intelligence Office under Contract No. W519TC-23-9-2033. The views and
+conclusions contained herein are those of the author(s) and should not be
+interpreted as necessarily representing the official policies or endorsements,
+either expressed or implied, of the U.S. Government.
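
The `Transforms: []` field in the README output corresponds to the `transforms` argument accepted by the dataset constructors; judging from `_base.py` further below, a transform is simply a callable applied to each image before it is returned. A hedged sketch, assuming MNIST forwards the argument to the base class and yields NumPy images (the normalization function is invented for illustration):

```python
import numpy as np
from maite_datasets.image_classification import MNIST

def to_unit_interval(image: np.ndarray) -> np.ndarray:
    """Illustrative transform: cast to float32 and scale pixel values into [0, 1]."""
    return image.astype(np.float32) / 255.0

# Transforms are applied in order to each image (see BaseDataset._transform).
mnist = MNIST(root="data", download=True, transforms=[to_unit_interval])
image, target, metadata = mnist[0]
print(image.dtype, image.shape)
```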
maite_datasets-0.0.1/pyproject.toml
@@ -0,0 +1,115 @@
+[project]
+name = "maite-datasets"
+description = "A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol."
+readme = "README.md"
+requires-python = ">=3.9"
+dynamic = ["version"]
+dependencies = [
+    "defusedxml>=0.7.1",
+    "numpy>=1.24.2",
+    "pillow>=10.3.0",
+    "requests>=2.32.3",
+    "typing-extensions>=4.12",
+]
+license = "MIT"
+authors = [
+    { name = "Andrew Weng", email = "andrew.weng@ariacoustics.com" },
+    { name = "Ryan Wood", email = "ryan.wood@ariacoustics.com" },
+    { name = "Shaun Jullens", email = "shaun.jullens@ariacoustics.com" },
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Framework :: Pytest",
+    "Operating System :: OS Independent",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+
+[project.optional-dependencies]
+tqdm = [
+    "tqdm>=4.66",
+]
+
+[dependency-groups]
+base = [
+    "nox[uv]>=2025.5.1",
+    "torch>=2.2.0",
+    "uv>=0.7.8",
+]
+lint = [
+    "ruff>=0.11",
+    "codespell[toml]>=2.3",
+]
+test = [
+    { include-group = "base" },
+    "pytest>=8.3",
+    "pytest-cov>=6.1",
+    "coverage[toml]>=7.6",
+]
+type = [
+    { include-group = "base" },
+    "pyright[nodejs]>=1.1.400",
+]
+dev = [
+    { include-group = "base" },
+    { include-group = "lint" },
+    { include-group = "test" },
+    { include-group = "type" },
+]
+
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cpu" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+explicit = true
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+addopts = [
+    "--pythonwarnings=ignore::DeprecationWarning",
+]
+markers = [
+    "required: marks tests for required features",
+    "optional: marks tests for optional features",
+    "year: marks tests that need a specified dataset year",
+]
+
+[tool.coverage.run]
+source = ["src/maite_datasets"]
+branch = true
+
+[tool.coverage.report]
+exclude_also = [
+    "raise NotImplementedError",
+    ": \\.\\.\\.",
+]
+include = ["*/src/maite_datasets/*"]
+fail_under = 90
+
+[tool.codespell]
+skip = './*env*,./output,uv.lock'
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/maite_datasets"]
+
+[tool.hatch.build.targets.sdist]
+include = [
+    "/src",
+    "/LICENSE",
+    "/README.md",
+]
+
+[tool.hatch.version]
+source = "vcs"
+
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
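
The `[dependency-groups]` table above uses PEP 735-style groups, which uv can install directly. A hedged sketch of a contributor workflow; these commands are not documented by the package itself and assume a recent uv is installed:

```bash
# Create or refresh a local environment with the package plus the dev group
# (which pulls in the base, lint, test, and type groups).
uv sync --group dev

# Run the test suite with coverage, using the pytest and coverage settings
# declared in pyproject.toml.
uv run pytest --cov
```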
maite_datasets-0.0.1/src/maite_datasets/__init__.py
@@ -0,0 +1 @@
+"""Module for MAITE compliant Computer Vision datasets."""
maite_datasets-0.0.1/src/maite_datasets/_base.py
@@ -0,0 +1,254 @@
+from __future__ import annotations
+
+__all__ = []
+
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
+
+import numpy as np
+
+from maite_datasets._fileio import _ensure_exists
+from maite_datasets._protocols import Array, Transform
+from maite_datasets._types import (
+    AnnotatedDataset,
+    DatasetMetadata,
+    DatumMetadata,
+    ImageClassificationDataset,
+    ObjectDetectionDataset,
+    ObjectDetectionTarget,
+)
+
+_TArray = TypeVar("_TArray", bound=Array)
+_TTarget = TypeVar("_TTarget")
+_TRawTarget = TypeVar(
+    "_TRawTarget",
+    Sequence[int],
+    Sequence[str],
+    Sequence[tuple[list[int], list[list[float]]]],
+)
+_TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
+
+
+def _to_datum_metadata(index: int, metadata: dict[str, Any]) -> DatumMetadata:
+    _id = metadata.pop("id", index)
+    return DatumMetadata(id=_id, **metadata)
+
+
+class DataLocation(NamedTuple):
+    url: str
+    filename: str
+    md5: bool
+    checksum: str
+
+
+class BaseDatasetMixin(Generic[_TArray]):
+    index2label: dict[int, str]
+
+    def _as_array(self, raw: list[Any]) -> _TArray: ...
+    def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
+    def _read_file(self, path: str) -> _TArray: ...
+
+
+class BaseDataset(
+    AnnotatedDataset[tuple[_TArray, _TTarget, DatumMetadata]],
+    Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
+):
+    """
+    Base class for internet downloaded datasets.
+    """
+
+    # Each subclass should override the attributes below.
+    # Each resource tuple must contain:
+    #    'url': str, the URL to download from
+    #    'filename': str, the name of the file once downloaded
+    #    'md5': boolean, True if it's the checksum value is md5
+    #    'checksum': str, the associated checksum for the downloaded file
+    _resources: list[DataLocation]
+    _resource_index: int = 0
+    index2label: dict[int, str]
+
+    def __init__(
+        self,
+        root: str | Path,
+        image_set: Literal["train", "val", "test", "operational", "base"] = "train",
+        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+        download: bool = False,
+        verbose: bool = False,
+    ) -> None:
+        self._root: Path = (
+            root.absolute() if isinstance(root, Path) else Path(root).absolute()
+        )
+        transforms = transforms if transforms is not None else []
+        self.transforms: Sequence[Transform[_TArray]] = (
+            transforms if isinstance(transforms, Sequence) else [transforms]
+        )
+        self.image_set = image_set
+        self._verbose = verbose
+
+        # Internal Attributes
+        self._download = download
+        self._filepaths: list[str]
+        self._targets: _TRawTarget
+        self._datum_metadata: dict[str, list[Any]]
+        self._resource: DataLocation = self._resources[self._resource_index]
+        self._label2index = {v: k for k, v in self.index2label.items()}
+
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._unique_id(),
+            index2label=self.index2label,
+            split=self.image_set,
+        )
+
+        # Load the data
+        self.path: Path = self._get_dataset_dir()
+        self._filepaths, self._targets, self._datum_metadata = self._load_data()
+        self.size: int = len(self._filepaths)
+
+    def __str__(self) -> str:
+        nt = "\n    "
+        title = f"{self.__class__.__name__} Dataset"
+        sep = "-" * len(title)
+        attrs = [
+            f"{k.capitalize()}: {v}"
+            for k, v in self.__dict__.items()
+            if not k.startswith("_")
+        ]
+        return f"{title}\n{sep}{nt}{nt.join(attrs)}"
+
+    @property
+    def label2index(self) -> dict[str, int]:
+        return self._label2index
+
+    def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, DatumMetadata]]:
+        for i in range(len(self)):
+            yield self[i]
+
+    def _get_dataset_dir(self) -> Path:
+        # Create a designated folder for this dataset (named after the class)
+        if self._root.stem.lower() == self.__class__.__name__.lower():
+            dataset_dir: Path = self._root
+        else:
+            dataset_dir: Path = self._root / self.__class__.__name__.lower()
+        if not dataset_dir.exists():
+            dataset_dir.mkdir(parents=True, exist_ok=True)
+        return dataset_dir
+
+    def _unique_id(self) -> str:
+        return f"{self.__class__.__name__}_{self.image_set}"
+
+    def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
+        """
+        Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.
+        """
+        if self._verbose:
+            print(f"Determining if {self._resource.filename} needs to be downloaded.")
+
+        try:
+            result = self._load_data_inner()
+            if self._verbose:
+                print("No download needed, loaded data successfully.")
+        except FileNotFoundError:
+            _ensure_exists(
+                *self._resource, self.path, self._root, self._download, self._verbose
+            )
+            result = self._load_data_inner()
+        return result
+
+    @abstractmethod
+    def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
+
+    def _transform(self, image: _TArray) -> _TArray:
+        """Function to transform the image prior to returning based on parameters passed in."""
+        for transform in self.transforms:
+            image = transform(image)
+        return image
+
+    def __len__(self) -> int:
+        return self.size
+
+
+class BaseICDataset(
+    BaseDataset[_TArray, _TArray, list[int], int],
+    BaseDatasetMixin[_TArray],
+    ImageClassificationDataset[_TArray],
+):
+    """
+    Base class for image classification datasets.
+    """
+
+    def __getitem__(self, index: int) -> tuple[_TArray, _TArray, DatumMetadata]:
+        """
+        Args
+        ----
+        index : int
+            Value of the desired data point
+
+        Returns
+        -------
+        tuple[TArray, TArray, DatumMetadata]
+            Image, target, datum_metadata - where target is one-hot encoding of class.
+        """
+        # Get the associated label and score
+        label = self._targets[index]
+        score = self._one_hot_encode(label)
+        # Get the image
+        img = self._read_file(self._filepaths[index])
+        img = self._transform(img)
+
+        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+
+        return img, score, _to_datum_metadata(index, img_metadata)
+
+
+class BaseODDataset(
+    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
+    BaseDatasetMixin[_TArray],
+    ObjectDetectionDataset[_TArray],
+):
+    """
+    Base class for object detection datasets.
+    """
+
+    _bboxes_per_size: bool = False
+
+    def __getitem__(
+        self, index: int
+    ) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
+        """
+        Args
+        ----
+        index : int
+            Value of the desired data point
+
+        Returns
+        -------
+        tuple[TArray, ObjectDetectionTarget[TArray], DatumMetadata]
+            Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
+        """
+        # Grab the bounding boxes and labels from the annotations
+        annotation = cast(_TAnnotation, self._targets[index])
+        boxes, labels, additional_metadata = self._read_annotations(annotation)
+        # Get the image
+        img = self._read_file(self._filepaths[index])
+        img_size = img.shape
+        img = self._transform(img)
+        # Adjust labels if necessary
+        if self._bboxes_per_size and boxes:
+            boxes = boxes * np.array(
+                [[img_size[1], img_size[2], img_size[1], img_size[2]]]
+            )
+        # Create the Object Detection Target
+        target = ObjectDetectionTarget(
+            self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels)
+        )
+
+        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+        img_metadata = img_metadata | additional_metadata
+
+        return img, target, _to_datum_metadata(index, img_metadata)
+
+    @abstractmethod
+    def _read_annotations(
+        self, annotation: _TAnnotation
+    ) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
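
In `_base.py` above, the dataset-specific work is deferred to a handful of hooks: `_load_data_inner` locates files and raw targets (raising `FileNotFoundError` to trigger a download via `_ensure_exists`), while `_as_array`, `_one_hot_encode`, and `_read_file` are normally supplied by the `_mixin` backends. Below is a minimal, hypothetical image-classification subclass that illustrates the contract only; the class name, resource URL, checksum, and file layout are invented for this sketch and are not part of the package:

```python
from __future__ import annotations

from typing import Any

import numpy as np
from PIL import Image

from maite_datasets._base import BaseICDataset, DataLocation


class ToyDataset(BaseICDataset[np.ndarray]):
    """Hypothetical dataset used only to illustrate the BaseDataset hooks."""

    # Download metadata consumed by _ensure_exists when local files are missing.
    _resources = [
        DataLocation(
            url="https://example.com/toy.zip",                # placeholder URL
            filename="toy.zip",
            md5=True,
            checksum="00000000000000000000000000000000",      # placeholder checksum
        )
    ]
    index2label = {0: "cat", 1: "dog"}

    # --- BaseDatasetMixin hooks (normally provided by the numpy/torch mixins) ---
    def _as_array(self, raw: list[Any]) -> np.ndarray:
        return np.asarray(raw)

    def _one_hot_encode(self, value: int | list[int]) -> np.ndarray:
        # Sufficient for a single label index in this classification sketch.
        encoded = np.zeros(len(self.index2label), dtype=np.float32)
        encoded[value] = 1.0
        return encoded

    def _read_file(self, path: str) -> np.ndarray:
        return np.asarray(Image.open(path))

    # --- BaseDataset hook: return file paths, raw targets, and per-datum metadata ---
    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        filepaths = sorted(str(p) for p in self.path.glob("*.png"))
        if not filepaths:
            # Signals _load_data to download/extract via _ensure_exists and retry.
            raise FileNotFoundError
        labels = [0 if "cat" in p else 1 for p in filepaths]
        return filepaths, labels, {}
```

Constructing `ToyDataset(root="data", download=True)` would then follow `_load_data`: try `_load_data_inner`, fall back to `_ensure_exists` on `FileNotFoundError`, and retry, after which `__getitem__` assembles the (image, one-hot target, metadata) tuples shown in the README.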