maite-datasets 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/PKG-INFO +40 -3
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/README.md +37 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/pyproject.toml +11 -4
- maite_datasets-0.0.7/src/maite_datasets/adapters/__init__.py +3 -0
- maite_datasets-0.0.7/src/maite_datasets/adapters/_huggingface.py +391 -0
- maite_datasets-0.0.7/src/maite_datasets/protocols.py +94 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/wrappers/_torch.py +5 -7
- maite_datasets-0.0.6/src/maite_datasets/protocols.py +0 -23
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/.gitignore +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/LICENSE +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/__init__.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_base.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_builder.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_collate.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_fileio.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_reader.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_validate.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/__init__.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_cifar10.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_mnist.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_ships.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/__init__.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_antiuav.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_coco.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_milco.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_seadrone.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_voc.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_yolo.py +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/py.typed +0 -0
- {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/wrappers/__init__.py +0 -0
{maite_datasets-0.0.6 → maite_datasets-0.0.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: maite-datasets
-Version: 0.0.6
+Version: 0.0.7
 Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
 Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
 License-Expression: MIT
@@ -10,11 +10,11 @@ Classifier: Framework :: Pytest
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Python: >=3.9
+Classifier: Programming Language :: Python :: 3.13
+Requires-Python: >=3.10
 Requires-Dist: defusedxml>=0.7.1
 Requires-Dist: maite<0.9,>=0.7
 Requires-Dist: numpy>=1.24.2
@@ -81,6 +81,8 @@ tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
 
 Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
 
+### Torchvision
+
 `TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
 `torchvision` transforms to the dataset.
 
@@ -129,6 +131,41 @@ type=Image, shape=torch.Size([3, 224, 224])
 tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
 ```
 
+## Dataset Adapters
+
+Adapters provide a way to read in datasets from other popular formats.
+
+### Huggingface
+
+Hugging face datasets can be adapted into MAITE compliant format using the `from_huggingface` adapter.
+
+```python
+>>> from datasets import load_dataset
+>>> from maite_datasets.adapters import from_huggingface
+
+>>> cppe5 = load_dataset("cppe-5")
+>>> m_cppe5 = from_huggingface(cppe5["train"])
+>>> print(m_cppe5)
+HFObjectDetection Dataset
+-------------------------
+Source: Dataset({
+    features: ['image_id', 'image', 'width', 'height', 'objects'],
+    num_rows: 1000
+})
+Metadata: {'id': 'cppe-5', 'index2label': {0: 'Coverall', 1: 'Face_Shield', 2: 'Gloves', 3: 'Goggles', 4: 'Mask'}, 'description': '', 'citation': '', 'homepage': '', 'license': '', 'features': {'image_id': Value('int64'), 'image': Image(mode=None, decode=True), 'width': Value('int32'), 'height': Value('int32'), 'objects': {'id': List(Value('int64')), 'area': List(Value('int64')), 'bbox': List(List(Value('float32'), length=4)), 'category': List(ClassLabel(names=['Coverall', 'Face_Shield', 'Gloves', 'Goggles', 'Mask']))}}, 'post_processed': None, 'supervised_keys': None, 'builder_name': 'parquet', 'dataset_name': 'cppe-5', 'config_name': 'default', 'version': 0.0.0, 'splits': {'train': SplitInfo(name='train', num_bytes=240478590, num_examples=1000, shard_lengths=None, dataset_name='cppe-5'), 'test': SplitInfo(name='test', num_bytes=4172706, num_examples=29, shard_lengths=None, dataset_name='cppe-5')}, 'download_checksums': {'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/train-00000-of-00001.parquet': {'num_bytes': 237015519, 'checksum': None}, 'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/test-00000-of-00001.parquet': {'num_bytes': 4137134, 'checksum': None}}, 'download_size': 241152653, 'post_processing_size': None, 'dataset_size': 244651296, 'size_in_bytes': 485803949}
+
+>>> image = m_cppe5[0][0]
+>>> print(f"type={image.__class__.__name__}, shape={image.shape}")
+type=ndarray, shape=(3, 663, 943)
+
+>>> target = m_cppe5[0][1]
+>>> print(f"box={target.boxes[0]}, label={target.labels[0]}")
+box=[302.0, 109.0, 73.0, 52.0], label=4
+
+>>> print(m_cppe5[0][2])
+{'id': [114, 115, 116, 117], 'image_id': 15, 'width': 943, 'height': 663, 'area': [3796, 1596, 152768, 81002]}
+```
+
 ## Additional Information
 
 For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).

{maite_datasets-0.0.6 → maite_datasets-0.0.7}/README.md
@@ -54,6 +54,8 @@ tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
 
 Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
 
+### Torchvision
+
 `TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
 `torchvision` transforms to the dataset.
 
@@ -102,6 +104,41 @@ type=Image, shape=torch.Size([3, 224, 224])
 tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
 ```
 
+## Dataset Adapters
+
+Adapters provide a way to read in datasets from other popular formats.
+
+### Huggingface
+
+Hugging face datasets can be adapted into MAITE compliant format using the `from_huggingface` adapter.
+
+```python
+>>> from datasets import load_dataset
+>>> from maite_datasets.adapters import from_huggingface
+
+>>> cppe5 = load_dataset("cppe-5")
+>>> m_cppe5 = from_huggingface(cppe5["train"])
+>>> print(m_cppe5)
+HFObjectDetection Dataset
+-------------------------
+Source: Dataset({
+    features: ['image_id', 'image', 'width', 'height', 'objects'],
+    num_rows: 1000
+})
+Metadata: {'id': 'cppe-5', 'index2label': {0: 'Coverall', 1: 'Face_Shield', 2: 'Gloves', 3: 'Goggles', 4: 'Mask'}, 'description': '', 'citation': '', 'homepage': '', 'license': '', 'features': {'image_id': Value('int64'), 'image': Image(mode=None, decode=True), 'width': Value('int32'), 'height': Value('int32'), 'objects': {'id': List(Value('int64')), 'area': List(Value('int64')), 'bbox': List(List(Value('float32'), length=4)), 'category': List(ClassLabel(names=['Coverall', 'Face_Shield', 'Gloves', 'Goggles', 'Mask']))}}, 'post_processed': None, 'supervised_keys': None, 'builder_name': 'parquet', 'dataset_name': 'cppe-5', 'config_name': 'default', 'version': 0.0.0, 'splits': {'train': SplitInfo(name='train', num_bytes=240478590, num_examples=1000, shard_lengths=None, dataset_name='cppe-5'), 'test': SplitInfo(name='test', num_bytes=4172706, num_examples=29, shard_lengths=None, dataset_name='cppe-5')}, 'download_checksums': {'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/train-00000-of-00001.parquet': {'num_bytes': 237015519, 'checksum': None}, 'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/test-00000-of-00001.parquet': {'num_bytes': 4137134, 'checksum': None}}, 'download_size': 241152653, 'post_processing_size': None, 'dataset_size': 244651296, 'size_in_bytes': 485803949}
+
+>>> image = m_cppe5[0][0]
+>>> print(f"type={image.__class__.__name__}, shape={image.shape}")
+type=ndarray, shape=(3, 663, 943)
+
+>>> target = m_cppe5[0][1]
+>>> print(f"box={target.boxes[0]}, label={target.labels[0]}")
+box=[302.0, 109.0, 73.0, 52.0], label=4
+
+>>> print(m_cppe5[0][2])
+{'id': [114, 115, 116, 117], 'image_id': 15, 'width': 943, 'height': 663, 'area': [3796, 1596, 152768, 81002]}
+```
+
 ## Additional Information
 
 For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).

{maite_datasets-0.0.6 → maite_datasets-0.0.7}/pyproject.toml
@@ -2,7 +2,7 @@
 name = "maite-datasets"
 description = "A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol."
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 dynamic = ["version"]
 dependencies = [
     "defusedxml>=0.7.1",
@@ -24,10 +24,10 @@ classifiers = [
     "Operating System :: OS Independent",
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
 ]
 
 [project.optional-dependencies]
@@ -37,28 +37,35 @@ tqdm = [
 
 [dependency-groups]
 base = [
-    "nox
+    "nox>=2025.5.1",
+    "nox-uv>=0.6.2",
+    "uv>=0.8.0",
+]
+more = [
     "torch>=2.2.0",
     "torchvision>=0.17.0",
     "tqdm>=4.66",
-    "uv>=0.8.0",
 ]
 lint = [
+    { include-group = "base" },
     "ruff>=0.11",
     "codespell[toml]>=2.3",
 ]
 test = [
     { include-group = "base" },
+    { include-group = "more" },
     "pytest>=8.3",
     "pytest-cov>=6.1",
     "coverage[toml]>=7.6",
 ]
 type = [
     { include-group = "base" },
+    { include-group = "more" },
     "pyright[nodejs]>=1.1.400",
 ]
 dev = [
     { include-group = "base" },
+    { include-group = "more" },
     { include-group = "lint" },
     { include-group = "test" },
     { include-group = "type" },

maite_datasets-0.0.7/src/maite_datasets/adapters/_huggingface.py
@@ -0,0 +1,391 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, Literal, TypeAlias, overload
+
+import maite.protocols.image_classification as ic
+import maite.protocols.object_detection as od
+import numpy as np
+from maite.protocols import DatasetMetadata, DatumMetadata
+
+from maite_datasets._base import BaseDataset, NumpyArray, ObjectDetectionTarget
+from maite_datasets.protocols import HFArray, HFClassLabel, HFDataset, HFImage, HFList, HFValue
+from maite_datasets.wrappers._torch import TTarget
+
+# Constants for image processing
+MAX_VALID_CHANNELS = 10
+
+FeatureDict: TypeAlias = Mapping[str, Any]
+
+
+@dataclass
+class HFDatasetInfo:
+    image_key: str
+
+
+@dataclass
+class HFImageClassificationDatasetInfo(HFDatasetInfo):
+    label_key: str
+
+
+@dataclass
+class HFObjectDetectionDatasetInfo(HFDatasetInfo):
+    objects_key: str
+    bbox_key: str
+    label_key: str
+
+
+class HFBaseDataset(BaseDataset[NumpyArray, TTarget]):
+    """Base wrapper for Hugging Face datasets, handling common logic."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, known_keys: set[str]) -> None:
+        self.source = hf_dataset
+        self._image_key = image_key
+
+        # Add dataset metadata
+        dataset_info_dict = hf_dataset.info.__dict__
+        if "id" in dataset_info_dict:
+            dataset_info_dict["datasetinfo_id"] = dataset_info_dict.pop("id")
+        self._metadata_id = dataset_info_dict["dataset_name"]
+        self._metadata_dict = dataset_info_dict
+
+        # Pre-validate features and cache metadata keys
+        self._validate_features(hf_dataset.features)
+        self._scalar_meta_keys = self._extract_scalar_meta_keys(hf_dataset.features, known_keys)
+
+        # Cache for image conversions
+        self._image_cache: dict[int, np.ndarray] = {}
+
+    def _validate_features(self, features: FeatureDict) -> None:
+        """Pre-validate all features during initialization."""
+        if self._image_key not in features:
+            raise ValueError(f"Image key '{self._image_key}' not found in dataset features.")
+
+        if not isinstance(features[self._image_key], (HFImage, HFArray)):
+            raise TypeError(f"Image feature '{self._image_key}' must be HFImage or HFArray.")
+
+    def _extract_scalar_meta_keys(self, features: FeatureDict, known_keys: set[str]) -> list[str]:
+        """Extract scalar metadata keys during initialization."""
+        return [key for key, feature in features.items() if key not in known_keys and isinstance(feature, HFValue)]
+
+    def __len__(self) -> int:
+        return len(self.source)
+
+    def _get_base_metadata(self, index: int) -> DatumMetadata:
+        """Extract base metadata for a datum."""
+        item = self.source[index]
+        datum_metadata: DatumMetadata = {"id": index}
+        for key in self._scalar_meta_keys:
+            datum_metadata[key] = item[key]
+        return datum_metadata
+
+    @lru_cache(maxsize=64)  # Cache image conversions
+    def _get_image(self, index: int) -> np.ndarray:
+        """Get and process image with caching and optimized conversions."""
+        # Convert to numpy array only once
+        raw_image = self.source[index][self._image_key]
+        image = np.asarray(raw_image)
+
+        # Handle different image formats efficiently
+        if image.ndim == 2:
+            # Grayscale: HW -> CHW
+            image = image[np.newaxis, ...]  # More efficient than expand_dims
+        elif image.ndim == 3:
+            # Check if we need to transpose from HWC to CHW
+            if image.shape[-1] < image.shape[-3] and image.shape[-1] <= MAX_VALID_CHANNELS:
+                # HWC -> CHW using optimized transpose
+                image = np.transpose(image, (2, 0, 1))
+            elif image.shape[0] > MAX_VALID_CHANNELS:
+                raise ValueError(
+                    f"Image at index {index} has invalid channel configuration. "
+                    f"Expected channels to be less than {MAX_VALID_CHANNELS}, got shape {image.shape}"
+                )
+        else:
+            raise ValueError(
+                f"Image at index {index} has unsupported dimensionality. "
+                f"Expected 2D or 3D, got {image.ndim}D with shape {image.shape}"
+            )
+
+        if image.ndim != 3:
+            raise ValueError(f"Image processing failed for index {index}. Final shape: {image.shape}")
+
+        return image
+
+
+class HFImageClassificationDataset(HFBaseDataset[NumpyArray], ic.Dataset):
+    """Wraps a Hugging Face dataset to comply with the ImageClassificationDataset protocol."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, label_key: str) -> None:
+        super().__init__(hf_dataset, image_key, known_keys={image_key, label_key})
+        self._label_key = label_key
+
+        # Pre-validate label feature
+        label_feature = hf_dataset.features[self._label_key]
+        if not isinstance(label_feature, HFClassLabel):
+            raise TypeError(
+                f"Label feature '{self._label_key}' must be a datasets.ClassLabel, got {type(label_feature).__name__}."
+            )
+
+        self._num_classes: int = label_feature.num_classes
+
+        # Pre-compute one-hot identity matrix for efficient encoding
+        self._one_hot_matrix = np.eye(self._num_classes, dtype=np.float32)
+
+        # Enhanced metadata with validation
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._metadata_id, index2label=dict(enumerate(label_feature.names), **self._metadata_dict)
+        )
+
+    def __getitem__(self, index: int) -> tuple[NumpyArray, NumpyArray, DatumMetadata]:
+        if not 0 <= index < len(self.source):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+        # Process image
+        image = self._get_image(index)
+        label_int = self.source[index][self._label_key]
+
+        # Process target
+        if not 0 <= label_int < self._num_classes:
+            raise ValueError(f"Label {label_int} at index {index} is out of range [0, {self._num_classes})")
+        one_hot_label = self._one_hot_matrix[label_int]
+
+        # Process metadata
+        datum_metadata = self._get_base_metadata(index)
+
+        return image, one_hot_label, datum_metadata
+
+
+class HFObjectDetectionDataset(HFBaseDataset[ObjectDetectionTarget], od.Dataset):
+    """Wraps a Hugging Face dataset to comply with the ObjectDetectionDataset protocol."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, objects_key: str, bbox_key: str, label_key: str) -> None:
+        super().__init__(hf_dataset, image_key, known_keys={image_key, objects_key})
+        self._objects_key = objects_key
+        self._bbox_key = bbox_key
+        self._label_key = label_key
+
+        # Pre-validate and extract object features
+        self._object_meta_keys = self._validate_and_extract_object_features(hf_dataset.features)
+
+        # Validate and extract label information
+        label_feature = self._extract_label_feature(hf_dataset.features)
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._metadata_id, index2label=dict(enumerate(label_feature.names)), **self._metadata_dict
+        )
+
+    def _validate_and_extract_object_features(self, features: FeatureDict) -> list[str]:
+        """Validate objects feature and extract metadata keys."""
+        objects_feature = features[self._objects_key]
+
+        # Determine the structure and get inner features
+        if isinstance(objects_feature, HFList):  # list(dict) case
+            if not isinstance(objects_feature.feature, dict):
+                raise TypeError(f"Objects feature '{self._objects_key}' with list type must contain dict features.")
+            inner_feature_dict = objects_feature.feature
+        elif isinstance(objects_feature, dict):  # dict(list) case
+            inner_feature_dict = objects_feature
+        else:
+            raise TypeError(
+                f"Objects feature '{self._objects_key}' must be a list or dict, got {type(objects_feature).__name__}."
+            )
+
+        # Validate required keys exist
+        required_keys = {self._bbox_key, self._label_key}
+        missing_keys = required_keys - set(inner_feature_dict.keys())
+        if missing_keys:
+            raise ValueError(f"Objects feature '{self._objects_key}' missing required keys: {missing_keys}")
+
+        # Extract object metadata keys
+        known_inner_keys = {self._bbox_key, self._label_key}
+        return [
+            key
+            for key, feature in inner_feature_dict.items()
+            if key not in known_inner_keys and isinstance(feature, (HFValue, HFList))
+        ]
+
+    def _extract_label_feature(self, features: FeatureDict) -> HFClassLabel:
+        """Extract and validate the label feature."""
+        objects_feature = features[self._objects_key]
+
+        inner_features = objects_feature.feature if isinstance(objects_feature, HFList) else objects_feature
+        label_feature_container = inner_features[self._label_key]
+        label_feature = (
+            label_feature_container.feature
+            if isinstance(label_feature_container.feature, HFClassLabel)
+            else label_feature_container
+        )
+
+        if not isinstance(label_feature, HFClassLabel):
+            raise TypeError(
+                f"Label '{self._label_key}' in '{self._objects_key}' must be a ClassLabel, "
+                f"got {type(label_feature).__name__}."
+            )
+
+        return label_feature
+
+    def __getitem__(self, index: int) -> tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]:
+        if not 0 <= index < len(self.source):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+        # Process image
+        image = self._get_image(index)
+        objects = self.source[index][self._objects_key]
+
+        # Process target
+        boxes = objects[self._bbox_key]
+        labels = objects[self._label_key]
+        scores = np.zeros_like(labels, dtype=np.float32)
+        target = ObjectDetectionTarget(boxes, labels, scores)
+
+        # Process metadata
+        datum_metadata = self._get_base_metadata(index)
+        self._add_object_metadata(objects, datum_metadata)
+
+        return image, target, datum_metadata
+
+    def _add_object_metadata(self, objects: dict[str, Any], datum_metadata: DatumMetadata) -> None:
+        """Efficiently add object metadata to datum metadata."""
+        if not objects[self._bbox_key]:  # No objects
+            return
+
+        num_objects = len(objects[self._bbox_key])
+
+        for key in self._object_meta_keys:
+            value = objects[key]
+            if isinstance(value, list):
+                if len(value) == num_objects:
+                    datum_metadata[key] = value
+                else:
+                    raise ValueError(
+                        f"Object metadata '{key}' length {len(value)} doesn't match number of objects {num_objects}"
+                    )
+            else:
+                datum_metadata[key] = [value] * num_objects
+
+
+def is_bbox(feature: Any) -> bool:
+    """Check if feature represents bounding box data with proper type validation."""
+    if not isinstance(feature, HFList):
+        return False
+
+    # Handle nested list structure
+    bbox_candidate = feature.feature if isinstance(feature.feature, HFList) else feature
+
+    return (
+        isinstance(bbox_candidate, HFList)
+        and bbox_candidate.length == 4
+        and isinstance(bbox_candidate.feature, HFValue)
+        and any(dtype in bbox_candidate.feature.dtype for dtype in ["float", "int"])
+    )
+
+
+def is_label(feature: Any) -> bool:
+    """Check if feature represents label data with proper type validation."""
+    target_feature = feature.feature if isinstance(feature, HFList) else feature
+    return isinstance(target_feature, HFClassLabel)
+
+
+def find_od_keys(feature: Any) -> tuple[str | None, str | None]:
+    """Helper to find bbox and label keys for object detection with improved logic."""
+    if not ((isinstance(feature, HFList) and isinstance(feature.feature, dict)) or isinstance(feature, dict)):
+        return None, None
+
+    inner_features: FeatureDict = feature.feature if isinstance(feature, HFList) else feature
+
+    bbox_key = label_key = None
+
+    for inner_name, inner_feature in inner_features.items():
+        if bbox_key is None and is_bbox(inner_feature):
+            bbox_key = inner_name
+        if label_key is None and is_label(inner_feature):
+            label_key = inner_name
+
+        # Early exit if both found
+        if bbox_key and label_key:
+            break
+
+    return bbox_key, label_key
+
+
+def get_dataset_info(dataset: HFDataset) -> HFDatasetInfo:
+    """Extract dataset information with improved validation and error messages."""
+    features = dataset.features
+    image_key = label_key = objects_key = bbox_key = None
+
+    # More efficient feature detection
+    for name, feature in features.items():
+        if image_key is None and isinstance(feature, (HFImage, HFArray)):
+            image_key = name
+        elif label_key is None and isinstance(feature, HFClassLabel):
+            label_key = name
+        elif objects_key is None:
+            temp_bbox, temp_label = find_od_keys(feature)
+            if temp_bbox and temp_label:
+                objects_key, bbox_key, label_key = name, temp_bbox, temp_label
+
+    if not image_key:
+        available_features = list(features.keys())
+        raise ValueError(
+            f"No image key found in dataset. Available features: {available_features}. "
+            f"Expected HFImage or HFArray type."
+        )
+
+    # Return appropriate dataset info based on detected features
+    if objects_key and bbox_key and label_key:
+        return HFObjectDetectionDatasetInfo(image_key, objects_key, bbox_key, label_key)
+    if label_key:
+        return HFImageClassificationDatasetInfo(image_key, label_key)
+    return HFDatasetInfo(image_key)
+
+
+@overload
+def from_huggingface(dataset: HFDataset, task: Literal["image_classification"]) -> HFImageClassificationDataset: ...
+
+
+@overload
+def from_huggingface(dataset: HFDataset, task: Literal["object_detection"]) -> HFObjectDetectionDataset: ...
+
+
+@overload
+def from_huggingface(
+    dataset: HFDataset, task: Literal["auto"] = "auto"
+) -> HFObjectDetectionDataset | HFImageClassificationDataset: ...
+
+
+def from_huggingface(
+    dataset: HFDataset, task: Literal["image_classification", "object_detection", "auto"] = "auto"
+) -> HFObjectDetectionDataset | HFImageClassificationDataset:
+    """Create appropriate dataset wrapper with enhanced error handling."""
+    info = get_dataset_info(dataset)
+
+    if isinstance(info, HFImageClassificationDatasetInfo):
+        if task in ("image_classification", "auto"):
+            return HFImageClassificationDataset(dataset, info.image_key, info.label_key)
+        if task == "object_detection":
+            raise ValueError(
+                f"Task mismatch: requested 'object_detection' but dataset appears to be "
+                f"image classification. Detected features: image='{info.image_key}', "
+                f"label='{info.label_key}'"
+            )
+
+    elif isinstance(info, HFObjectDetectionDatasetInfo):
+        if task in ("object_detection", "auto"):
+            return HFObjectDetectionDataset(dataset, info.image_key, info.objects_key, info.bbox_key, info.label_key)
+        if task == "image_classification":
+            raise ValueError(
+                f"Task mismatch: requested 'image_classification' but dataset appears to be "
+                f"object detection. Detected features: image='{info.image_key}', "
+                f"objects='{info.objects_key}'"
+            )
+
+    # Enhanced error message for auto-detection failure
+    available_features = list(dataset.features.keys())
+    feature_types = {k: type(v).__name__ for k, v in dataset.features.items()}
+
+    raise ValueError(
+        f"Could not automatically determine task for requested type '{task}'. "
+        f"Detected info: {info}. Available features: {available_features}. "
+        f"Feature types: {feature_types}. Ensure dataset has proper image and label/objects features."
+    )
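
The new `from_huggingface` entry point defaults to `task="auto"`, and the overloads above also allow the task to be requested explicitly, raising `ValueError` when the detected features do not match. A minimal usage sketch of both call styles, reusing the cppe-5 dataset from the README example (the dataset choice and split are illustrative):

```python
from datasets import load_dataset

from maite_datasets.adapters import from_huggingface

train = load_dataset("cppe-5")["train"]

# Default: infer the task from the dataset features.
auto_ds = from_huggingface(train)

# Explicit: request object detection; a mismatched task raises ValueError.
od_ds = from_huggingface(train, task="object_detection")

image, target, metadata = od_ds[0]
print(image.shape, target.labels[0], metadata["id"])
```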

maite_datasets-0.0.7/src/maite_datasets/protocols.py
@@ -0,0 +1,94 @@
+"""
+Common type protocols used for interoperability.
+"""
+
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import Any, Protocol, overload, runtime_checkable
+
+
+@runtime_checkable
+class Array(Protocol):
+    """
+    Protocol for interoperable array objects.
+
+    Supports common array representations with popular libraries like
+    PyTorch, Tensorflow and JAX, as well as NumPy arrays.
+    """
+
+    @property
+    def shape(self) -> tuple[int, ...]: ...
+    def __array__(self) -> Any: ...
+    def __getitem__(self, key: Any, /) -> Any: ...
+    def __iter__(self) -> Iterator[Any]: ...
+    def __len__(self) -> int: ...
+
+
+@runtime_checkable
+class HFDatasetInfo(Protocol):
+    @property
+    def dataset_name(self) -> str: ...
+
+
+@runtime_checkable
+class HFDataset(Protocol):
+    @property
+    def features(self) -> Mapping[str, Any]: ...
+
+    @property
+    def builder_name(self) -> str | None: ...
+
+    @property
+    def info(self) -> HFDatasetInfo: ...
+
+    @overload
+    def __getitem__(self, key: int | slice | Iterable[int]) -> dict[str, Any]: ...
+    @overload
+    def __getitem__(self, key: str) -> Sequence[int]: ...
+    def __getitem__(self, key: str | int | slice | Iterable[int]) -> dict[str, Any] | Sequence[int]: ...
+
+    def __len__(self) -> int: ...
+
+
+@runtime_checkable
+class HFFeature(Protocol):
+    @property
+    def _type(self) -> str: ...
+
+
+@runtime_checkable
+class HFClassLabel(HFFeature, Protocol):
+    @property
+    def names(self) -> list[str]: ...
+
+    @property
+    def num_classes(self) -> int: ...
+
+
+@runtime_checkable
+class HFImage(HFFeature, Protocol):
+    @property
+    def decode(self) -> bool: ...
+
+
+@runtime_checkable
+class HFArray(HFFeature, Protocol):
+    @property
+    def shape(self) -> tuple[int, ...]: ...
+    @property
+    def dtype(self) -> str: ...
+
+
+@runtime_checkable
+class HFList(HFFeature, Protocol):
+    @property
+    def feature(self) -> Any: ...
+    @property
+    def length(self) -> int: ...
+
+
+@runtime_checkable
+class HFValue(HFFeature, Protocol):
+    @property
+    def pa_type(self) -> Any: ...  # pyarrow type ... not documented
+    @property
+    def dtype(self) -> str: ...
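
The feature protocols above are `runtime_checkable`, so the adapter matches `datasets` feature objects structurally rather than importing that library. A small sketch of that behavior with a hypothetical stand-in class (`FakeClassLabel` is illustrative, not part of the package):

```python
from maite_datasets.protocols import HFClassLabel


class FakeClassLabel:
    """Structurally satisfies HFClassLabel without subclassing it."""

    _type = "ClassLabel"

    @property
    def names(self) -> list[str]:
        return ["cat", "dog"]

    @property
    def num_classes(self) -> int:
        return len(self.names)


# runtime_checkable protocols only check attribute presence at isinstance() time.
assert isinstance(FakeClassLabel(), HFClassLabel)
```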

{maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/wrappers/_torch.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
 
-import numpy as np
 import torch
 from maite.protocols import DatasetMetadata, DatumMetadata
 from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
@@ -70,23 +69,22 @@ class TorchvisionWrapper(Generic[TArray, TTarget]):
         image, target, metadata = self._dataset[index]
 
         # Convert image to torch tensor
-        torch_image = torch.
-        torch_image = Image(torch_image)
+        torch_image = Image(torch.tensor(image))
 
         # Handle different target types
         if isinstance(target, Array):
             # Image classification case
-            torch_target = torch.
+            torch_target = torch.tensor(target, dtype=torch.float32)
             torch_datum = self._transform((torch_image, torch_target, metadata))
             return cast(TorchvisionImageClassificationDatum, torch_datum)
 
         if isinstance(target, _ObjectDetectionTarget):
             # Object detection case
             torch_boxes = BoundingBoxes(
-                torch.
+                torch.tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
             ) # type: ignore
-            torch_labels = torch.
-            torch_scores = torch.
+            torch_labels = torch.tensor(target.labels, dtype=torch.int64)
+            torch_scores = torch.tensor(target.scores, dtype=torch.float32)
             torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
             torch_datum = self._transform((torch_image, torch_target, metadata))
             return cast(TorchvisionObjectDetectionDatum, torch_datum)

maite_datasets-0.0.6/src/maite_datasets/protocols.py
@@ -1,23 +0,0 @@
-"""
-Common type protocols used for interoperability.
-"""
-
-from collections.abc import Iterator
-from typing import Any, Protocol, runtime_checkable
-
-
-@runtime_checkable
-class Array(Protocol):
-    """
-    Protocol for interoperable array objects.
-
-    Supports common array representations with popular libraries like
-    PyTorch, Tensorflow and JAX, as well as NumPy arrays.
-    """
-
-    @property
-    def shape(self) -> tuple[int, ...]: ...
-    def __array__(self) -> Any: ...
-    def __getitem__(self, key: Any, /) -> Any: ...
-    def __iter__(self) -> Iterator[Any]: ...
-    def __len__(self) -> int: ...