maite-datasets 0.0.6__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/PKG-INFO +40 -3
  2. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/README.md +37 -0
  3. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/pyproject.toml +11 -4
  4. maite_datasets-0.0.7/src/maite_datasets/adapters/__init__.py +3 -0
  5. maite_datasets-0.0.7/src/maite_datasets/adapters/_huggingface.py +391 -0
  6. maite_datasets-0.0.7/src/maite_datasets/protocols.py +94 -0
  7. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/wrappers/_torch.py +5 -7
  8. maite_datasets-0.0.6/src/maite_datasets/protocols.py +0 -23
  9. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/.gitignore +0 -0
  10. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/LICENSE +0 -0
  11. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/__init__.py +0 -0
  12. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_base.py +0 -0
  13. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_builder.py +0 -0
  14. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_collate.py +0 -0
  15. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_fileio.py +0 -0
  16. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_reader.py +0 -0
  17. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/_validate.py +0 -0
  18. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/__init__.py +0 -0
  19. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_cifar10.py +0 -0
  20. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_mnist.py +0 -0
  21. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/image_classification/_ships.py +0 -0
  22. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/__init__.py +0 -0
  23. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_antiuav.py +0 -0
  24. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_coco.py +0 -0
  25. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_milco.py +0 -0
  26. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_seadrone.py +0 -0
  27. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_voc.py +0 -0
  28. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/object_detection/_yolo.py +0 -0
  29. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/py.typed +0 -0
  30. {maite_datasets-0.0.6 → maite_datasets-0.0.7}/src/maite_datasets/wrappers/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: maite-datasets
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
5
5
  Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
6
6
  License-Expression: MIT
@@ -10,11 +10,11 @@ Classifier: Framework :: Pytest
10
10
  Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Operating System :: OS Independent
12
12
  Classifier: Programming Language :: Python :: 3 :: Only
13
- Classifier: Programming Language :: Python :: 3.9
14
13
  Classifier: Programming Language :: Python :: 3.10
15
14
  Classifier: Programming Language :: Python :: 3.11
16
15
  Classifier: Programming Language :: Python :: 3.12
17
- Requires-Python: >=3.9
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Requires-Python: >=3.10
18
18
  Requires-Dist: defusedxml>=0.7.1
19
19
  Requires-Dist: maite<0.9,>=0.7
20
20
  Requires-Dist: numpy>=1.24.2
@@ -81,6 +81,8 @@ tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
81
81
 
82
82
  Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
83
83
 
84
+ ### Torchvision
85
+
84
86
  `TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
85
87
  `torchvision` transforms to the dataset.
86
88
 
@@ -129,6 +131,41 @@ type=Image, shape=torch.Size([3, 224, 224])
129
131
  tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
130
132
  ```
131
133
 
134
+ ## Dataset Adapters
135
+
136
+ Adapters provide a way to read in datasets from other popular formats.
137
+
138
+ ### Huggingface
139
+
140
+ Hugging Face datasets can be adapted into a MAITE-compliant format using the `from_huggingface` adapter.
141
+
142
+ ```python
143
+ >>> from datasets import load_dataset
144
+ >>> from maite_datasets.adapters import from_huggingface
145
+
146
+ >>> cppe5 = load_dataset("cppe-5")
147
+ >>> m_cppe5 = from_huggingface(cppe5["train"])
148
+ >>> print(m_cppe5)
149
+ HFObjectDetection Dataset
150
+ -------------------------
151
+ Source: Dataset({
152
+ features: ['image_id', 'image', 'width', 'height', 'objects'],
153
+ num_rows: 1000
154
+ })
155
+ Metadata: {'id': 'cppe-5', 'index2label': {0: 'Coverall', 1: 'Face_Shield', 2: 'Gloves', 3: 'Goggles', 4: 'Mask'}, 'description': '', 'citation': '', 'homepage': '', 'license': '', 'features': {'image_id': Value('int64'), 'image': Image(mode=None, decode=True), 'width': Value('int32'), 'height': Value('int32'), 'objects': {'id': List(Value('int64')), 'area': List(Value('int64')), 'bbox': List(List(Value('float32'), length=4)), 'category': List(ClassLabel(names=['Coverall', 'Face_Shield', 'Gloves', 'Goggles', 'Mask']))}}, 'post_processed': None, 'supervised_keys': None, 'builder_name': 'parquet', 'dataset_name': 'cppe-5', 'config_name': 'default', 'version': 0.0.0, 'splits': {'train': SplitInfo(name='train', num_bytes=240478590, num_examples=1000, shard_lengths=None, dataset_name='cppe-5'), 'test': SplitInfo(name='test', num_bytes=4172706, num_examples=29, shard_lengths=None, dataset_name='cppe-5')}, 'download_checksums': {'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/train-00000-of-00001.parquet': {'num_bytes': 237015519, 'checksum': None}, 'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/test-00000-of-00001.parquet': {'num_bytes': 4137134, 'checksum': None}}, 'download_size': 241152653, 'post_processing_size': None, 'dataset_size': 244651296, 'size_in_bytes': 485803949}
156
+
157
+ >>> image = m_cppe5[0][0]
158
+ >>> print(f"type={image.__class__.__name__}, shape={image.shape}")
159
+ type=ndarray, shape=(3, 663, 943)
160
+
161
+ >>> target = m_cppe5[0][1]
162
+ >>> print(f"box={target.boxes[0]}, label={target.labels[0]}")
163
+ box=[302.0, 109.0, 73.0, 52.0], label=4
164
+
165
+ >>> print(m_cppe5[0][2])
166
+ {'id': [114, 115, 116, 117], 'image_id': 15, 'width': 943, 'height': 663, 'area': [3796, 1596, 152768, 81002]}
167
+ ```
168
+
132
169
  ## Additional Information
133
170
 
134
171
  For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
@@ -54,6 +54,8 @@ tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
54
54
 
55
55
  Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
56
56
 
57
+ ### Torchvision
58
+
57
59
  `TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
58
60
  `torchvision` transforms to the dataset.
59
61
 
@@ -102,6 +104,41 @@ type=Image, shape=torch.Size([3, 224, 224])
102
104
  tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
103
105
  ```
104
106
 
107
+ ## Dataset Adapters
108
+
109
+ Adapters provide a way to read in datasets from other popular formats.
110
+
111
+ ### Huggingface
112
+
113
+ Hugging Face datasets can be adapted into a MAITE-compliant format using the `from_huggingface` adapter.
114
+
115
+ ```python
116
+ >>> from datasets import load_dataset
117
+ >>> from maite_datasets.adapters import from_huggingface
118
+
119
+ >>> cppe5 = load_dataset("cppe-5")
120
+ >>> m_cppe5 = from_huggingface(cppe5["train"])
121
+ >>> print(m_cppe5)
122
+ HFObjectDetection Dataset
123
+ -------------------------
124
+ Source: Dataset({
125
+ features: ['image_id', 'image', 'width', 'height', 'objects'],
126
+ num_rows: 1000
127
+ })
128
+ Metadata: {'id': 'cppe-5', 'index2label': {0: 'Coverall', 1: 'Face_Shield', 2: 'Gloves', 3: 'Goggles', 4: 'Mask'}, 'description': '', 'citation': '', 'homepage': '', 'license': '', 'features': {'image_id': Value('int64'), 'image': Image(mode=None, decode=True), 'width': Value('int32'), 'height': Value('int32'), 'objects': {'id': List(Value('int64')), 'area': List(Value('int64')), 'bbox': List(List(Value('float32'), length=4)), 'category': List(ClassLabel(names=['Coverall', 'Face_Shield', 'Gloves', 'Goggles', 'Mask']))}}, 'post_processed': None, 'supervised_keys': None, 'builder_name': 'parquet', 'dataset_name': 'cppe-5', 'config_name': 'default', 'version': 0.0.0, 'splits': {'train': SplitInfo(name='train', num_bytes=240478590, num_examples=1000, shard_lengths=None, dataset_name='cppe-5'), 'test': SplitInfo(name='test', num_bytes=4172706, num_examples=29, shard_lengths=None, dataset_name='cppe-5')}, 'download_checksums': {'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/train-00000-of-00001.parquet': {'num_bytes': 237015519, 'checksum': None}, 'hf://datasets/cppe-5@66f6a5efd474e35bd7cb94bf15dea27d4c6ad3f8/data/test-00000-of-00001.parquet': {'num_bytes': 4137134, 'checksum': None}}, 'download_size': 241152653, 'post_processing_size': None, 'dataset_size': 244651296, 'size_in_bytes': 485803949}
129
+
130
+ >>> image = m_cppe5[0][0]
131
+ >>> print(f"type={image.__class__.__name__}, shape={image.shape}")
132
+ type=ndarray, shape=(3, 663, 943)
133
+
134
+ >>> target = m_cppe5[0][1]
135
+ >>> print(f"box={target.boxes[0]}, label={target.labels[0]}")
136
+ box=[302.0, 109.0, 73.0, 52.0], label=4
137
+
138
+ >>> print(m_cppe5[0][2])
139
+ {'id': [114, 115, 116, 117], 'image_id': 15, 'width': 943, 'height': 663, 'area': [3796, 1596, 152768, 81002]}
140
+ ```
141
+
105
142
  ## Additional Information
106
143
 
107
144
  For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
@@ -2,7 +2,7 @@
2
2
  name = "maite-datasets"
3
3
  description = "A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol."
4
4
  readme = "README.md"
5
- requires-python = ">=3.9"
5
+ requires-python = ">=3.10"
6
6
  dynamic = ["version"]
7
7
  dependencies = [
8
8
  "defusedxml>=0.7.1",
@@ -24,10 +24,10 @@ classifiers = [
24
24
  "Operating System :: OS Independent",
25
25
  "License :: OSI Approved :: MIT License",
26
26
  "Programming Language :: Python :: 3 :: Only",
27
- "Programming Language :: Python :: 3.9",
28
27
  "Programming Language :: Python :: 3.10",
29
28
  "Programming Language :: Python :: 3.11",
30
29
  "Programming Language :: Python :: 3.12",
30
+ "Programming Language :: Python :: 3.13",
31
31
  ]
32
32
 
33
33
  [project.optional-dependencies]
@@ -37,28 +37,35 @@ tqdm = [
37
37
 
38
38
  [dependency-groups]
39
39
  base = [
40
- "nox[uv]>=2025.5.1",
40
+ "nox>=2025.5.1",
41
+ "nox-uv>=0.6.2",
42
+ "uv>=0.8.0",
43
+ ]
44
+ more = [
41
45
  "torch>=2.2.0",
42
46
  "torchvision>=0.17.0",
43
47
  "tqdm>=4.66",
44
- "uv>=0.8.0",
45
48
  ]
46
49
  lint = [
50
+ { include-group = "base" },
47
51
  "ruff>=0.11",
48
52
  "codespell[toml]>=2.3",
49
53
  ]
50
54
  test = [
51
55
  { include-group = "base" },
56
+ { include-group = "more" },
52
57
  "pytest>=8.3",
53
58
  "pytest-cov>=6.1",
54
59
  "coverage[toml]>=7.6",
55
60
  ]
56
61
  type = [
57
62
  { include-group = "base" },
63
+ { include-group = "more" },
58
64
  "pyright[nodejs]>=1.1.400",
59
65
  ]
60
66
  dev = [
61
67
  { include-group = "base" },
68
+ { include-group = "more" },
62
69
  { include-group = "lint" },
63
70
  { include-group = "test" },
64
71
  { include-group = "type" },
@@ -0,0 +1,3 @@
1
+ from ._huggingface import HFImageClassificationDataset, HFObjectDetectionDataset, from_huggingface
2
+
3
+ __all__ = ["HFImageClassificationDataset", "HFObjectDetectionDataset", "from_huggingface"]
@@ -0,0 +1,391 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+ from dataclasses import dataclass
5
+ from functools import lru_cache
6
+ from typing import Any, Literal, TypeAlias, overload
7
+
8
+ import maite.protocols.image_classification as ic
9
+ import maite.protocols.object_detection as od
10
+ import numpy as np
11
+ from maite.protocols import DatasetMetadata, DatumMetadata
12
+
13
+ from maite_datasets._base import BaseDataset, NumpyArray, ObjectDetectionTarget
14
+ from maite_datasets.protocols import HFArray, HFClassLabel, HFDataset, HFImage, HFList, HFValue
15
+ from maite_datasets.wrappers._torch import TTarget
16
+
17
+ # Constants for image processing
18
+ MAX_VALID_CHANNELS = 10
19
+
20
+ FeatureDict: TypeAlias = Mapping[str, Any]
21
+
22
+
23
@dataclass
class HFDatasetInfo:
    """Feature keys detected in a Hugging Face dataset.

    The base record only carries the image feature; the task-specific
    subclasses add the keys needed to read targets.
    """

    # Name of the feature that holds the image data.
    image_key: str


@dataclass
class HFImageClassificationDatasetInfo(HFDatasetInfo):
    """Detected keys for a dataset that looks like image classification."""

    # Name of the top-level feature that holds the class label.
    label_key: str


@dataclass
class HFObjectDetectionDatasetInfo(HFDatasetInfo):
    """Detected keys for a dataset that looks like object detection."""

    # Name of the top-level feature that groups the per-object annotations.
    objects_key: str
    # Name of the bounding-box entry inside the objects feature.
    bbox_key: str
    # Name of the label entry inside the objects feature.
    label_key: str
38
+
39
+
40
class HFBaseDataset(BaseDataset[NumpyArray, TTarget]):
    """Base wrapper for Hugging Face datasets, handling common logic.

    Stores the source dataset, derives dataset-level metadata from
    ``hf_dataset.info``, validates the image feature and records which
    top-level scalar features are exposed as per-datum metadata.

    Parameters
    ----------
    hf_dataset : HFDataset
        Source dataset; ``info``, ``features``, ``__getitem__`` and ``__len__`` are used.
    image_key : str
        Name of the feature holding the image data.
    known_keys : set[str]
        Feature names already consumed by the task-specific subclass. They are
        excluded from the per-datum scalar metadata.

    Raises
    ------
    KeyError
        If the dataset info has no ``dataset_name`` entry.
    ValueError
        If ``image_key`` is not present in the dataset features.
    TypeError
        If the image feature is neither an ``HFImage`` nor an ``HFArray``.
    """

    # Upper bound on the number of converted images kept per instance.
    _IMAGE_CACHE_SIZE = 64

    def __init__(self, hf_dataset: HFDataset, image_key: str, known_keys: set[str]) -> None:
        self.source = hf_dataset
        self._image_key = image_key

        # Work on a shallow copy: renaming "id" below must not mutate the info
        # object owned by the source dataset.
        dataset_info_dict = dict(hf_dataset.info.__dict__)
        if "id" in dataset_info_dict:
            # "id" is reserved for the dataset metadata id, so keep the original under another name.
            dataset_info_dict["datasetinfo_id"] = dataset_info_dict.pop("id")
        self._metadata_id = dataset_info_dict["dataset_name"]
        self._metadata_dict = dataset_info_dict

        # Pre-validate features and cache metadata keys
        self._validate_features(hf_dataset.features)
        self._scalar_meta_keys = self._extract_scalar_meta_keys(hf_dataset.features, known_keys)

        # Per-instance cache of converted images, ordered from least to most recently used.
        self._image_cache: dict[int, np.ndarray] = {}

    def _validate_features(self, features: FeatureDict) -> None:
        """Check that the image feature exists and has a supported type.

        Raises
        ------
        ValueError
            If the image key is missing from ``features``.
        TypeError
            If the image feature is neither an ``HFImage`` nor an ``HFArray``.
        """
        if self._image_key not in features:
            raise ValueError(f"Image key '{self._image_key}' not found in dataset features.")

        if not isinstance(features[self._image_key], (HFImage, HFArray)):
            raise TypeError(f"Image feature '{self._image_key}' must be HFImage or HFArray.")

    def _extract_scalar_meta_keys(self, features: FeatureDict, known_keys: set[str]) -> list[str]:
        """Return the names of ``HFValue`` features that are not in ``known_keys``."""
        return [key for key, feature in features.items() if key not in known_keys and isinstance(feature, HFValue)]

    def __len__(self) -> int:
        return len(self.source)

    def _get_base_metadata(self, index: int) -> DatumMetadata:
        """Build the datum metadata: the index as ``id`` plus every scalar metadata feature."""
        item = self.source[index]
        datum_metadata: DatumMetadata = {"id": index}
        for key in self._scalar_meta_keys:
            datum_metadata[key] = item[key]
        return datum_metadata

    def _get_image(self, index: int) -> np.ndarray:
        """Return the image at ``index`` as a 3D array, reusing a recent conversion when available.

        Conversions are kept in a bounded per-instance cache rather than in a
        ``functools.lru_cache`` on the method, which would hold a reference to
        every dataset instance for the lifetime of the process.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D, or its leading axis exceeds ``MAX_VALID_CHANNELS``.
        """
        cache = self._image_cache
        cached = cache.pop(index, None)
        if cached is not None:
            # Re-insert so this entry becomes the most recently used one.
            cache[index] = cached
            return cached

        image = self._convert_image(index)
        cache[index] = image
        if len(cache) > self._IMAGE_CACHE_SIZE:
            # Dicts preserve insertion order, so the first key is the least recently used.
            cache.pop(next(iter(cache)), None)
        return image

    def _convert_image(self, index: int) -> np.ndarray:
        """Read the raw image at ``index`` and convert it to a 3D channel-first array."""
        raw_image = self.source[index][self._image_key]
        image = np.asarray(raw_image)

        if image.ndim == 2:
            # Grayscale: HW -> CHW
            return image[np.newaxis, ...]

        if image.ndim != 3:
            raise ValueError(
                f"Image at index {index} has unsupported dimensionality. "
                f"Expected 2D or 3D, got {image.ndim}D with shape {image.shape}"
            )

        # Heuristic: a small trailing axis that is shorter than the leading one is treated as channels.
        if image.shape[-1] < image.shape[-3] and image.shape[-1] <= MAX_VALID_CHANNELS:
            # HWC -> CHW
            return np.transpose(image, (2, 0, 1))

        if image.shape[0] > MAX_VALID_CHANNELS:
            raise ValueError(
                f"Image at index {index} has invalid channel configuration. "
                f"Expected channels to be less than {MAX_VALID_CHANNELS}, got shape {image.shape}"
            )

        return image
115
+
116
+
117
class HFImageClassificationDataset(HFBaseDataset[NumpyArray], ic.Dataset):
    """Wraps a Hugging Face dataset to comply with the ImageClassificationDataset protocol.

    Each datum is ``(image, one_hot_label, datum_metadata)``, where the label
    is a ``float32`` one-hot vector of length ``num_classes``.

    Parameters
    ----------
    hf_dataset : HFDataset
        Source dataset.
    image_key : str
        Name of the feature holding the image data.
    label_key : str
        Name of the feature holding the class label; must be a ``ClassLabel``.

    Raises
    ------
    TypeError
        If the label feature is not a ``ClassLabel``.
    """

    def __init__(self, hf_dataset: HFDataset, image_key: str, label_key: str) -> None:
        super().__init__(hf_dataset, image_key, known_keys={image_key, label_key})
        self._label_key = label_key

        # Pre-validate label feature
        label_feature = hf_dataset.features[self._label_key]
        if not isinstance(label_feature, HFClassLabel):
            raise TypeError(
                f"Label feature '{self._label_key}' must be a datasets.ClassLabel, got {type(label_feature).__name__}."
            )

        self._num_classes: int = label_feature.num_classes

        # Pre-compute one-hot identity matrix for efficient encoding
        self._one_hot_matrix = np.eye(self._num_classes, dtype=np.float32)

        # index2label holds only the class names; the dataset-info fields are
        # passed alongside it as metadata entries, not merged into the label map.
        self.metadata: DatasetMetadata = DatasetMetadata(
            id=self._metadata_id, index2label=dict(enumerate(label_feature.names)), **self._metadata_dict
        )

    def __getitem__(self, index: int) -> tuple[NumpyArray, NumpyArray, DatumMetadata]:
        """Return the image, one-hot label and metadata for ``index``.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[0, len(self))``.
        ValueError
            If the stored label is outside ``[0, num_classes)``.
        """
        if not 0 <= index < len(self.source):
            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")

        # Process image
        image = self._get_image(index)
        label_int = self.source[index][self._label_key]

        # Process target
        if not 0 <= label_int < self._num_classes:
            raise ValueError(f"Label {label_int} at index {index} is out of range [0, {self._num_classes})")
        # Copy the row so in-place edits by a caller cannot corrupt the shared identity matrix.
        one_hot_label = self._one_hot_matrix[label_int].copy()

        # Process metadata
        datum_metadata = self._get_base_metadata(index)

        return image, one_hot_label, datum_metadata
158
+
159
+
160
class HFObjectDetectionDataset(HFBaseDataset[ObjectDetectionTarget], od.Dataset):
    """Wraps a Hugging Face dataset to comply with the ObjectDetectionDataset protocol.

    Each datum is ``(image, target, datum_metadata)``. Boxes and labels are
    passed through as stored in the source dataset (no box-format conversion
    is performed here) and scores are all zeros.

    Parameters
    ----------
    hf_dataset : HFDataset
        Source dataset.
    image_key : str
        Name of the feature holding the image data.
    objects_key : str
        Name of the feature grouping the per-object annotations.
    bbox_key : str
        Name of the bounding-box entry inside the objects feature.
    label_key : str
        Name of the label entry inside the objects feature.

    Raises
    ------
    TypeError
        If the objects feature is neither list-like nor a dict, or the label
        entry is not (a list of) ``ClassLabel``.
    ValueError
        If the objects feature lacks the bbox or label entry.
    """

    def __init__(self, hf_dataset: HFDataset, image_key: str, objects_key: str, bbox_key: str, label_key: str) -> None:
        super().__init__(hf_dataset, image_key, known_keys={image_key, objects_key})
        self._objects_key = objects_key
        self._bbox_key = bbox_key
        self._label_key = label_key

        # Pre-validate and extract object features
        self._object_meta_keys = self._validate_and_extract_object_features(hf_dataset.features)

        # Validate and extract label information
        label_feature = self._extract_label_feature(hf_dataset.features)
        self.metadata: DatasetMetadata = DatasetMetadata(
            id=self._metadata_id, index2label=dict(enumerate(label_feature.names)), **self._metadata_dict
        )

    def _validate_and_extract_object_features(self, features: FeatureDict) -> list[str]:
        """Validate the objects feature and return the names of its extra (metadata) entries."""
        objects_feature = features[self._objects_key]

        # Determine the structure and get inner features
        if isinstance(objects_feature, HFList):  # list(dict) case
            if not isinstance(objects_feature.feature, dict):
                raise TypeError(f"Objects feature '{self._objects_key}' with list type must contain dict features.")
            inner_feature_dict = objects_feature.feature
        elif isinstance(objects_feature, dict):  # dict(list) case
            inner_feature_dict = objects_feature
        else:
            raise TypeError(
                f"Objects feature '{self._objects_key}' must be a list or dict, got {type(objects_feature).__name__}."
            )

        # Validate required keys exist
        required_keys = {self._bbox_key, self._label_key}
        missing_keys = required_keys - set(inner_feature_dict.keys())
        if missing_keys:
            raise ValueError(f"Objects feature '{self._objects_key}' missing required keys: {missing_keys}")

        # Everything besides bbox/label that is a value or list becomes per-object metadata.
        return [
            key
            for key, feature in inner_feature_dict.items()
            if key not in required_keys and isinstance(feature, (HFValue, HFList))
        ]

    def _extract_label_feature(self, features: FeatureDict) -> HFClassLabel:
        """Return the ``ClassLabel`` describing object labels.

        The label entry may be a ``ClassLabel`` itself or a list-like feature
        wrapping one.

        Raises
        ------
        TypeError
            If neither the entry nor its wrapped feature is a ``ClassLabel``.
        """
        objects_feature = features[self._objects_key]

        inner_features = objects_feature.feature if isinstance(objects_feature, HFList) else objects_feature
        label_feature_container = inner_features[self._label_key]

        # A bare ClassLabel has no ``feature`` attribute, so look it up defensively.
        wrapped_feature = getattr(label_feature_container, "feature", None)
        label_feature = wrapped_feature if isinstance(wrapped_feature, HFClassLabel) else label_feature_container

        if not isinstance(label_feature, HFClassLabel):
            raise TypeError(
                f"Label '{self._label_key}' in '{self._objects_key}' must be a ClassLabel, "
                f"got {type(label_feature).__name__}."
            )

        return label_feature

    def _objects_as_columns(self, objects: Any) -> dict[str, Any]:
        """Return the objects of one row as a mapping of entry name to per-object values.

        Rows that already arrive as a mapping are returned unchanged.
        NOTE(review): rows of a list(dict) feature are assumed to arrive as a
        list of per-object dicts; confirm against the installed ``datasets`` version.
        """
        if not isinstance(objects, list):
            return objects
        keys = [self._bbox_key, self._label_key, *self._object_meta_keys]
        return {key: [obj[key] for obj in objects] for key in keys}

    def __getitem__(self, index: int) -> tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]:
        """Return the image, detection target and metadata for ``index``.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[0, len(self))``.
        ValueError
            If a per-object metadata list does not match the number of boxes.
        """
        if not 0 <= index < len(self.source):
            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")

        # Process image
        image = self._get_image(index)
        objects = self._objects_as_columns(self.source[index][self._objects_key])

        # Process target
        boxes = objects[self._bbox_key]
        labels = objects[self._label_key]
        # Ground-truth annotations carry no confidence, so scores are zero-filled.
        scores = np.zeros_like(labels, dtype=np.float32)
        target = ObjectDetectionTarget(boxes, labels, scores)

        # Process metadata
        datum_metadata = self._get_base_metadata(index)
        self._add_object_metadata(objects, datum_metadata)

        return image, target, datum_metadata

    def _add_object_metadata(self, objects: dict[str, Any], datum_metadata: DatumMetadata) -> None:
        """Copy per-object metadata entries into ``datum_metadata`` in place.

        List values must have one element per box; any other value is repeated
        once per box. Nothing is added when the row has no boxes.

        Raises
        ------
        ValueError
            If a list value's length differs from the number of boxes.
        """
        if not objects[self._bbox_key]:  # No objects
            return

        num_objects = len(objects[self._bbox_key])

        for key in self._object_meta_keys:
            value = objects[key]
            if not isinstance(value, list):
                datum_metadata[key] = [value] * num_objects
            elif len(value) == num_objects:
                datum_metadata[key] = value
            else:
                raise ValueError(
                    f"Object metadata '{key}' length {len(value)} doesn't match number of objects {num_objects}"
                )
266
+
267
+
268
def is_bbox(feature: Any) -> bool:
    """Return True when ``feature`` looks like bounding-box data.

    A match is a list-like feature, or a list-like feature wrapping another
    one, whose length is 4 and whose element is a value feature with a dtype
    name containing ``"float"`` or ``"int"``.
    """
    if not isinstance(feature, HFList):
        return False

    # Unwrap one level when the boxes are stored as a list of lists.
    candidate = feature
    if isinstance(feature.feature, HFList):
        candidate = feature.feature

    if not isinstance(candidate, HFList) or candidate.length != 4:
        return False

    element = candidate.feature
    if not isinstance(element, HFValue):
        return False

    return any(kind in element.dtype for kind in ["float", "int"])
282
+
283
+
284
def is_label(feature: Any) -> bool:
    """Return True when ``feature`` is a class label or a list-like feature wrapping one."""
    if isinstance(feature, HFList):
        # Labels stored per object: inspect the element feature instead.
        feature = feature.feature
    return isinstance(feature, HFClassLabel)
288
+
289
+
290
def find_od_keys(feature: Any) -> tuple[str | None, str | None]:
    """Locate the bbox and label entries inside a candidate objects feature.

    Returns ``(bbox_key, label_key)``; either element is ``None`` when no
    matching entry is found. ``(None, None)`` is returned when ``feature`` is
    neither a dict nor a list-like feature wrapping a dict.
    """
    wraps_features = isinstance(feature, HFList)
    if not (wraps_features and isinstance(feature.feature, dict)) and not isinstance(feature, dict):
        return None, None

    entries: FeatureDict = feature.feature if wraps_features else feature

    found_bbox = None
    found_label = None
    for entry_name, entry_feature in entries.items():
        # Only the first match of each kind is kept.
        if found_bbox is None and is_bbox(entry_feature):
            found_bbox = entry_name
        if found_label is None and is_label(entry_feature):
            found_label = entry_name

        # Nothing left to look for once both are known.
        if found_bbox and found_label:
            break

    return found_bbox, found_label
310
+
311
+
312
def get_dataset_info(dataset: HFDataset) -> HFDatasetInfo:
    """Detect which features of ``dataset`` hold the image, the label and the objects.

    Returns an ``HFObjectDetectionDatasetInfo`` when an objects feature with
    bbox and label entries is found, an ``HFImageClassificationDatasetInfo``
    when only a label is found, and a plain ``HFDatasetInfo`` otherwise.

    Raises
    ------
    ValueError
        If no image feature is found.
    """
    features = dataset.features
    image_key = None
    label_key = None
    objects_key = None
    bbox_key = None

    # Each feature is claimed by the first role it qualifies for.
    for name, feature in features.items():
        if image_key is None and isinstance(feature, (HFImage, HFArray)):
            image_key = name
            continue
        if label_key is None and isinstance(feature, HFClassLabel):
            label_key = name
            continue
        if objects_key is not None:
            continue

        candidate_bbox, candidate_label = find_od_keys(feature)
        if candidate_bbox and candidate_label:
            # The per-object label replaces any top-level label found earlier.
            objects_key, bbox_key, label_key = name, candidate_bbox, candidate_label

    if not image_key:
        available_features = list(features.keys())
        raise ValueError(
            f"No image key found in dataset. Available features: {available_features}. "
            f"Expected HFImage or HFArray type."
        )

    # Return appropriate dataset info based on detected features
    if objects_key and bbox_key and label_key:
        return HFObjectDetectionDatasetInfo(image_key, objects_key, bbox_key, label_key)
    if label_key:
        return HFImageClassificationDatasetInfo(image_key, label_key)
    return HFDatasetInfo(image_key)
341
+
342
+
343
@overload
def from_huggingface(dataset: HFDataset, task: Literal["image_classification"]) -> HFImageClassificationDataset: ...


@overload
def from_huggingface(dataset: HFDataset, task: Literal["object_detection"]) -> HFObjectDetectionDataset: ...


@overload
def from_huggingface(
    dataset: HFDataset, task: Literal["auto"] = "auto"
) -> HFObjectDetectionDataset | HFImageClassificationDataset: ...


def from_huggingface(
    dataset: HFDataset, task: Literal["image_classification", "object_detection", "auto"] = "auto"
) -> HFObjectDetectionDataset | HFImageClassificationDataset:
    """Wrap a Hugging Face dataset in the matching MAITE-compliant dataset class.

    Parameters
    ----------
    dataset : HFDataset
        Source dataset to adapt.
    task : {"image_classification", "object_detection", "auto"}, default "auto"
        Requested task. With ``"auto"`` the task is taken from the detected features.

    Returns
    -------
    HFObjectDetectionDataset | HFImageClassificationDataset
        Wrapper for the detected (and requested) task.

    Raises
    ------
    ValueError
        If the requested task contradicts the detected features, or no
        supported task can be determined.
    """
    info = get_dataset_info(dataset)

    if isinstance(info, HFObjectDetectionDatasetInfo):
        if task == "image_classification":
            raise ValueError(
                f"Task mismatch: requested 'image_classification' but dataset appears to be "
                f"object detection. Detected features: image='{info.image_key}', "
                f"objects='{info.objects_key}'"
            )
        if task in ("object_detection", "auto"):
            return HFObjectDetectionDataset(dataset, info.image_key, info.objects_key, info.bbox_key, info.label_key)

    elif isinstance(info, HFImageClassificationDatasetInfo):
        if task == "object_detection":
            raise ValueError(
                f"Task mismatch: requested 'object_detection' but dataset appears to be "
                f"image classification. Detected features: image='{info.image_key}', "
                f"label='{info.label_key}'"
            )
        if task in ("image_classification", "auto"):
            return HFImageClassificationDataset(dataset, info.image_key, info.label_key)

    # Reached when no task-specific features were detected or ``task`` is not a supported value.
    available_features = list(dataset.features.keys())
    feature_types = {name: type(feature).__name__ for name, feature in dataset.features.items()}

    raise ValueError(
        f"Could not automatically determine task for requested type '{task}'. "
        f"Detected info: {info}. Available features: {available_features}. "
        f"Feature types: {feature_types}. Ensure dataset has proper image and label/objects features."
    )
@@ -0,0 +1,94 @@
1
+ """
2
+ Common type protocols used for interoperability.
3
+ """
4
+
5
+ from collections.abc import Iterable, Iterator, Mapping, Sequence
6
+ from typing import Any, Protocol, overload, runtime_checkable
7
+
8
+
9
@runtime_checkable
class Array(Protocol):
    """
    Structural type for interoperable array objects.

    Any object exposing ``shape``, ``__array__``, ``__getitem__``,
    ``__iter__`` and ``__len__`` satisfies it, which covers NumPy arrays as
    well as the array types of PyTorch, Tensorflow and JAX.
    """

    @property
    def shape(self) -> tuple[int, ...]:
        """Size of the array along each dimension."""
        ...

    def __array__(self) -> Any:
        ...

    def __getitem__(self, key: Any, /) -> Any:
        ...

    def __iter__(self) -> Iterator[Any]:
        ...

    def __len__(self) -> int:
        ...
24
+
25
+
26
@runtime_checkable
class HFDatasetInfo(Protocol):
    """Structural type for a dataset info object; only ``dataset_name`` is required."""

    @property
    def dataset_name(self) -> str:
        """Name of the dataset."""
        ...
30
+
31
+
32
@runtime_checkable
class HFDataset(Protocol):
    """Structural type for the dataset interface the adapters rely on.

    An object matches when it exposes ``features``, ``builder_name``,
    ``info``, ``__getitem__`` and ``__len__``.
    """

    @property
    def features(self) -> Mapping[str, Any]:
        """Mapping of feature name to feature description."""
        ...

    @property
    def builder_name(self) -> str | None:
        ...

    @property
    def info(self) -> HFDatasetInfo:
        """Dataset-level information object."""
        ...

    @overload
    def __getitem__(self, key: int | slice | Iterable[int]) -> dict[str, Any]: ...

    @overload
    def __getitem__(self, key: str) -> Sequence[int]: ...

    def __getitem__(self, key: str | int | slice | Iterable[int]) -> dict[str, Any] | Sequence[int]:
        ...

    def __len__(self) -> int:
        ...
50
+
51
+
52
@runtime_checkable
class HFFeature(Protocol):
    """Structural base for feature descriptions; requires only ``_type``."""

    @property
    def _type(self) -> str:
        ...


@runtime_checkable
class HFClassLabel(HFFeature, Protocol):
    """Feature describing a categorical label with named classes."""

    @property
    def names(self) -> list[str]:
        """Class names, positioned by label index."""
        ...

    @property
    def num_classes(self) -> int:
        """Number of classes."""
        ...


@runtime_checkable
class HFImage(HFFeature, Protocol):
    """Feature describing image data; matched by the presence of ``decode``."""

    @property
    def decode(self) -> bool:
        ...


@runtime_checkable
class HFArray(HFFeature, Protocol):
    """Feature describing fixed-shape array data; matched by ``shape`` and ``dtype``."""

    @property
    def shape(self) -> tuple[int, ...]:
        ...

    @property
    def dtype(self) -> str:
        ...


@runtime_checkable
class HFList(HFFeature, Protocol):
    """Feature describing a list of another feature; matched by ``feature`` and ``length``."""

    @property
    def feature(self) -> Any:
        """Description of the contained feature."""
        ...

    @property
    def length(self) -> int:
        ...


@runtime_checkable
class HFValue(HFFeature, Protocol):
    """Feature describing a scalar value; matched by ``pa_type`` and ``dtype``."""

    @property
    def pa_type(self) -> Any:  # pyarrow type ... not documented
        ...

    @property
    def dtype(self) -> str:
        ...
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
4
4
 
5
- import numpy as np
6
5
  import torch
7
6
  from maite.protocols import DatasetMetadata, DatumMetadata
8
7
  from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
@@ -70,23 +69,22 @@ class TorchvisionWrapper(Generic[TArray, TTarget]):
70
69
  image, target, metadata = self._dataset[index]
71
70
 
72
71
  # Convert image to torch tensor
73
- torch_image = torch.from_numpy(image) if isinstance(image, np.ndarray) else torch.as_tensor(image)
74
- torch_image = Image(torch_image)
72
+ torch_image = Image(torch.tensor(image))
75
73
 
76
74
  # Handle different target types
77
75
  if isinstance(target, Array):
78
76
  # Image classification case
79
- torch_target = torch.as_tensor(target, dtype=torch.float32)
77
+ torch_target = torch.tensor(target, dtype=torch.float32)
80
78
  torch_datum = self._transform((torch_image, torch_target, metadata))
81
79
  return cast(TorchvisionImageClassificationDatum, torch_datum)
82
80
 
83
81
  if isinstance(target, _ObjectDetectionTarget):
84
82
  # Object detection case
85
83
  torch_boxes = BoundingBoxes(
86
- torch.as_tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
84
+ torch.tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
87
85
  ) # type: ignore
88
- torch_labels = torch.as_tensor(target.labels, dtype=torch.int64)
89
- torch_scores = torch.as_tensor(target.scores, dtype=torch.float32)
86
+ torch_labels = torch.tensor(target.labels, dtype=torch.int64)
87
+ torch_scores = torch.tensor(target.scores, dtype=torch.float32)
90
88
  torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
91
89
  torch_datum = self._transform((torch_image, torch_target, metadata))
92
90
  return cast(TorchvisionObjectDetectionDatum, torch_datum)
@@ -1,23 +0,0 @@
1
- """
2
- Common type protocols used for interoperability.
3
- """
4
-
5
- from collections.abc import Iterator
6
- from typing import Any, Protocol, runtime_checkable
7
-
8
-
9
- @runtime_checkable
10
- class Array(Protocol):
11
- """
12
- Protocol for interoperable array objects.
13
-
14
- Supports common array representations with popular libraries like
15
- PyTorch, Tensorflow and JAX, as well as NumPy arrays.
16
- """
17
-
18
- @property
19
- def shape(self) -> tuple[int, ...]: ...
20
- def __array__(self) -> Any: ...
21
- def __getitem__(self, key: Any, /) -> Any: ...
22
- def __iter__(self) -> Iterator[Any]: ...
23
- def __len__(self) -> int: ...
File without changes