maite-datasets 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +16 -1
- maite_datasets/_collate.py +112 -0
- maite_datasets/_protocols.py +20 -25
- maite_datasets/_reader/__init__.py +6 -0
- maite_datasets/_reader/_base.py +135 -0
- maite_datasets/_reader/_coco.py +287 -0
- maite_datasets/_reader/_factory.py +64 -0
- maite_datasets/_reader/_yolo.py +312 -0
- maite_datasets/_validate.py +169 -0
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.4.dist-info}/METADATA +1 -1
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.4.dist-info}/RECORD +13 -6
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.4.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.4.dist-info}/licenses/LICENSE +0 -0
maite_datasets/__init__.py
CHANGED
@@ -1,5 +1,20 @@
|
|
1
1
|
"""Module for MAITE compliant Computer Vision datasets."""
|
2
2
|
|
3
3
|
from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
|
4
|
+
from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
|
5
|
+
from maite_datasets._validate import validate_dataset
|
6
|
+
from maite_datasets._reader._factory import create_dataset_reader
|
7
|
+
from maite_datasets._reader._coco import COCODatasetReader
|
8
|
+
from maite_datasets._reader._yolo import YOLODatasetReader
|
4
9
|
|
5
|
-
__all__ = [
|
10
|
+
__all__ = [
|
11
|
+
"collate_as_list",
|
12
|
+
"collate_as_numpy",
|
13
|
+
"collate_as_torch",
|
14
|
+
"create_dataset_reader",
|
15
|
+
"to_image_classification_dataset",
|
16
|
+
"to_object_detection_dataset",
|
17
|
+
"validate_dataset",
|
18
|
+
"COCODatasetReader",
|
19
|
+
"YOLODatasetReader",
|
20
|
+
]
|
@@ -0,0 +1,112 @@
|
|
1
|
+
"""
|
2
|
+
Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
__all__ = []
|
8
|
+
|
9
|
+
from collections.abc import Iterable, Sequence
|
10
|
+
from typing import Any, TypeVar, TYPE_CHECKING
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
import torch
|
17
|
+
|
18
|
+
from maite_datasets._protocols import ArrayLike
|
19
|
+
|
20
|
+
T_in = TypeVar("T_in")
|
21
|
+
T_tgt = TypeVar("T_tgt")
|
22
|
+
T_md = TypeVar("T_md")
|
23
|
+
|
24
|
+
|
25
|
+
def collate_as_list(
    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` triples into three parallel lists.

    Meant to be passed as ``collate_fn`` to ``torch.utils.data.DataLoader`` when
    neither the targets nor the metadata are tensors, so no stacking is attempted
    for any of the three components.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[T_in, T_tgt, T_md]]
        The individual datums making up one batch.

    Returns
    -------
    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
        Lists holding the inputs, the targets and the metadata, in batch order.
    """
    inputs: list[T_in] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []

    for datum in batch_data_as_singles:
        # Unpacking here raises ValueError for any datum that is not a 3-tuple.
        x, y, md = datum
        inputs.append(x)
        targets.append(y)
        metadatas.append(md)

    return inputs, targets, metadatas
|
52
|
+
|
53
|
+
|
54
|
+
def collate_as_numpy(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` triples, stacking the inputs into one NumPy array.

    Every input is converted with ``np.asarray`` and the results are combined with
    ``np.stack``, so all inputs must share the same shape. Targets and metadata are
    returned as plain lists.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        The individual datums making up one batch.

    Returns
    -------
    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
        The stacked input array (an empty 1-D array for an empty batch), the list
        of targets and the list of metadata.
    """
    arrays: list[NDArray[Any]] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []

    for x, y, md in batch_data_as_singles:
        arrays.append(np.asarray(x))
        targets.append(y)
        metadatas.append(md)

    if not arrays:
        # np.stack rejects an empty sequence, so an empty batch yields an empty array.
        return np.array([]), targets, metadatas
    return np.stack(arrays), targets, metadatas
|
80
|
+
|
81
|
+
|
82
|
+
def collate_as_torch(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` triples, stacking the inputs into one torch Tensor.

    Every input is converted with ``torch.as_tensor`` and the results are combined
    with ``torch.stack``, so all inputs must share the same shape. Targets and
    metadata are returned as plain lists.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        The individual datums making up one batch.

    Returns
    -------
    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
        The stacked input tensor (an empty tensor for an empty batch), the list of
        targets and the list of metadata.

    Raises
    ------
    ImportError
        If PyTorch cannot be imported.
    """
    try:
        import torch
    except ImportError as err:
        # Chain the original error so the underlying cause (e.g. a broken install
        # rather than a missing one) remains visible in the traceback.
        raise ImportError("PyTorch is not installed. Please install it to use this function.") from err

    input_batch: list[torch.Tensor] = []
    target_batch: list[T_tgt] = []
    metadata_batch: list[T_md] = []
    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
        input_batch.append(torch.as_tensor(input_datum))
        target_batch.append(target_datum)
        metadata_batch.append(metadata_datum)

    # torch.stack rejects an empty list, so an empty batch yields an empty tensor.
    return torch.stack(input_batch) if input_batch else torch.tensor([]), target_batch, metadata_batch
|
maite_datasets/_protocols.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
"""
|
2
|
-
Common type protocols used for interoperability
|
2
|
+
Common type protocols used for interoperability.
|
3
3
|
"""
|
4
4
|
|
5
|
+
from collections.abc import Iterator
|
5
6
|
import sys
|
6
7
|
from typing import (
|
7
8
|
Any,
|
8
9
|
Generic,
|
9
|
-
Iterator,
|
10
|
-
Mapping,
|
11
10
|
Protocol,
|
11
|
+
TypeAlias,
|
12
12
|
TypedDict,
|
13
13
|
TypeVar,
|
14
14
|
runtime_checkable,
|
@@ -36,29 +36,10 @@ See Also
|
|
36
36
|
@runtime_checkable
|
37
37
|
class Array(Protocol):
|
38
38
|
"""
|
39
|
-
Protocol for array objects
|
39
|
+
Protocol for interoperable array objects.
|
40
40
|
|
41
41
|
Supports common array representations with popular libraries like
|
42
42
|
PyTorch, Tensorflow and JAX, as well as NumPy arrays.
|
43
|
-
|
44
|
-
Example
|
45
|
-
-------
|
46
|
-
>>> import numpy as np
|
47
|
-
>>> import torch
|
48
|
-
>>> from maite_datasets._typing import Array
|
49
|
-
|
50
|
-
Create array objects
|
51
|
-
|
52
|
-
>>> ndarray = np.random.random((10, 10))
|
53
|
-
>>> tensor = torch.tensor([1, 2, 3])
|
54
|
-
|
55
|
-
Check type at runtime
|
56
|
-
|
57
|
-
>>> isinstance(ndarray, Array)
|
58
|
-
True
|
59
|
-
|
60
|
-
>>> isinstance(tensor, Array)
|
61
|
-
True
|
62
43
|
"""
|
63
44
|
|
64
45
|
@property
|
@@ -71,6 +52,7 @@ class Array(Protocol):
|
|
71
52
|
|
72
53
|
_T = TypeVar("_T")
|
73
54
|
_T_co = TypeVar("_T_co", covariant=True)
|
55
|
+
_T_cn = TypeVar("_T_cn", contravariant=True)
|
74
56
|
|
75
57
|
|
76
58
|
class DatasetMetadata(TypedDict, total=False):
|
@@ -89,6 +71,19 @@ class DatasetMetadata(TypedDict, total=False):
|
|
89
71
|
index2label: NotRequired[ReadOnly[dict[int, str]]]
|
90
72
|
|
91
73
|
|
74
|
+
class DatumMetadata(TypedDict, total=False):
    """
    Per-datum metadata that every `AnnotatedDataset` class must supply.

    The TypedDict is declared with ``total=False``; ``id`` is the only key that is
    explicitly marked as required.

    Attributes
    ----------
    id : Required[str]
        Identifier that is unique for the datum
    """

    id: Required[ReadOnly[str]]
|
85
|
+
|
86
|
+
|
92
87
|
@runtime_checkable
|
93
88
|
class Dataset(Generic[_T_co], Protocol):
|
94
89
|
"""
|
@@ -134,7 +129,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
|
|
134
129
|
# ========== IMAGE CLASSIFICATION DATASETS ==========
|
135
130
|
|
136
131
|
|
137
|
-
ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike,
|
132
|
+
ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, DatumMetadata]
|
138
133
|
"""
|
139
134
|
Type alias for an image classification datum tuple.
|
140
135
|
|
@@ -174,7 +169,7 @@ class ObjectDetectionTarget(Protocol):
|
|
174
169
|
def scores(self) -> ArrayLike: ...
|
175
170
|
|
176
171
|
|
177
|
-
ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget,
|
172
|
+
ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, DatumMetadata]
|
178
173
|
"""
|
179
174
|
Type alias for an object detection datum tuple.
|
180
175
|
|
@@ -0,0 +1,135 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
import logging
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from maite_datasets._protocols import ArrayLike, ObjectDetectionDataset
|
11
|
+
|
12
|
+
_logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class _ObjectDetectionTarget:
|
16
|
+
"""Internal implementation of ObjectDetectionTarget protocol."""
|
17
|
+
|
18
|
+
def __init__(self, boxes: ArrayLike, labels: ArrayLike, scores: ArrayLike) -> None:
|
19
|
+
self._boxes = np.asarray(boxes)
|
20
|
+
self._labels = np.asarray(labels)
|
21
|
+
self._scores = np.asarray(scores)
|
22
|
+
|
23
|
+
@property
|
24
|
+
def boxes(self) -> ArrayLike:
|
25
|
+
return self._boxes
|
26
|
+
|
27
|
+
@property
|
28
|
+
def labels(self) -> ArrayLike:
|
29
|
+
return self._labels
|
30
|
+
|
31
|
+
@property
|
32
|
+
def scores(self) -> ArrayLike:
|
33
|
+
return self._scores
|
34
|
+
|
35
|
+
|
36
|
+
class BaseDatasetReader(ABC):
    """
    Abstract base class for object detection dataset readers.

    Provides common functionality for dataset path handling, validation,
    and dataset creation while allowing format-specific implementations.
    Subclasses implement ``_initialize_format_specific``,
    ``_create_dataset_implementation``, ``_validate_format_specific`` and the
    ``index2label`` property.

    Parameters
    ----------
    dataset_path : str or Path
        Root directory containing dataset files
    dataset_id : str or None, default None
        Dataset identifier. If None, uses dataset_path name

    Raises
    ------
    FileNotFoundError
        If ``dataset_path`` does not exist.
    """

    # Image suffixes counted by ``validate_structure`` when the subclass does not
    # define its own ``image_extensions`` attribute.
    _DEFAULT_IMAGE_EXTENSIONS: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, dataset_path: str | Path, dataset_id: str | None = None) -> None:
        self.dataset_path = Path(dataset_path)
        self._dataset_id = dataset_id or self.dataset_path.name

        # Basic path validation
        if not self.dataset_path.exists():
            raise FileNotFoundError(f"Dataset path not found: {self.dataset_path}")

        # Format-specific initialization; subclasses must set any attributes this
        # hook needs before calling ``super().__init__``.
        self._initialize_format_specific()

    @abstractmethod
    def _initialize_format_specific(self) -> None:
        """Initialize format-specific components (annotations, classes, etc.)."""

    @abstractmethod
    def _create_dataset_implementation(self) -> ObjectDetectionDataset:
        """Create the format-specific dataset implementation."""

    @abstractmethod
    def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
        """Validate format-specific structure and return issues and stats."""

    @property
    @abstractmethod
    def index2label(self) -> dict[int, str]:
        """Mapping from class index to class name."""

    def _validate_images_directory(self) -> tuple[list[str], dict[str, Any]]:
        """
        Validate the images directory and return issues and stats.

        The directory name comes from the subclass's ``images_dir`` attribute when it
        exists (falling back to ``"images"``), so readers configured with a custom
        images directory are validated against the directory they actually read.
        Suffixes are compared case-insensitively against the subclass's
        ``image_extensions`` (or ``_DEFAULT_IMAGE_EXTENSIONS``), and each file is
        counted exactly once.
        """
        issues: list[str] = []
        stats: dict[str, Any] = {}

        images_dir = getattr(self, "images_dir", "images")
        images_path = self.dataset_path / images_dir
        if not images_path.is_dir():
            issues.append(f"Missing {images_dir}/ directory")
            return issues, stats

        extensions = {ext.lower() for ext in getattr(self, "image_extensions", self._DEFAULT_IMAGE_EXTENSIONS)}
        # A single pass with lower-cased suffixes: globbing "*.jpg" and "*.JPG"
        # separately double-counts every file on case-insensitive filesystems.
        image_files = [p for p in images_path.iterdir() if p.is_file() and p.suffix.lower() in extensions]

        stats["num_images"] = len(image_files)
        if not image_files:
            issues.append(f"No image files found in {images_dir}/ directory")

        return issues, stats

    def validate_structure(self) -> dict[str, Any]:
        """
        Validate dataset directory structure and return diagnostic information.

        Returns
        -------
        dict[str, Any]
            Validation results containing:
            - is_valid: bool indicating if structure is valid
            - issues: list of validation issues found
            - stats: dict with dataset statistics
        """
        # Validate images directory (common to all formats)
        issues, stats = self._validate_images_directory()

        # Format-specific validation
        format_issues, format_stats = self._validate_format_specific()
        issues.extend(format_issues)
        stats.update(format_stats)

        return {"is_valid": len(issues) == 0, "issues": issues, "stats": stats}

    def get_dataset(self) -> ObjectDetectionDataset:
        """
        Get dataset conforming to MAITE ObjectDetectionDataset protocol.

        Returns
        -------
        ObjectDetectionDataset
            Dataset instance with MAITE-compatible interface
        """
        return self._create_dataset_implementation()
|
@@ -0,0 +1,287 @@
|
|
1
|
+
"""Dataset reader for COCO detection format."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import json
|
6
|
+
import logging
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from PIL import Image
|
12
|
+
|
13
|
+
from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
|
14
|
+
from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
|
15
|
+
|
16
|
+
_logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class COCODatasetReader(BaseDatasetReader):
    """
    COCO format dataset reader conforming to MAITE protocols.

    Reads COCO format object detection datasets from disk and provides
    MAITE-compatible interface.

    Directory Structure Requirements
    --------------------------------
    ```
    dataset_root/
    ├── images/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    ├── annotations.json    # COCO format annotation file
    └── classes.txt         # Optional: one class name per line
    ```

    COCO Format Specifications
    --------------------------
    annotations.json structure:
    ```json
    {
        "images": [
            {
                "id": 1,
                "file_name": "image1.jpg",
                "width": 640,
                "height": 480
            }
        ],
        "annotations": [
            {
                "id": 1,
                "image_id": 1,
                "category_id": 1,
                "bbox": [100, 50, 200, 150],  // [x, y, width, height]
                "area": 30000
            }
        ],
        "categories": [
            {
                "id": 1,
                "name": "person"
            }
        ]
    }
    ```

    classes.txt format (optional, one class per line, ordered by index):
    ```
    person
    bicycle
    car
    motorcycle
    ```

    Parameters
    ----------
    dataset_path : str or Path
        Root directory containing COCO dataset files
    annotation_file : str, default "annotations.json"
        Name of COCO annotation JSON file
    images_dir : str, default "images"
        Name of directory containing images
    classes_file : str or None, default "classes.txt"
        Optional file containing class names (one per line)
        If None, uses category names from COCO annotations
    dataset_id : str or None, default None
        Dataset identifier. If None, uses dataset_path name

    Raises
    ------
    FileNotFoundError
        If the dataset path, the annotation file or the images directory is missing.

    Notes
    -----
    COCO annotations should follow standard COCO format with:
    - "images": list of image metadata
    - "annotations": list of bounding box annotations
    - "categories": list of category definitions

    Bounding boxes are converted from COCO format (x, y, width, height)
    to MAITE format (x1, y1, x2, y2).

    Labels are positional indices into the "categories" list. When classes.txt is
    used, its lines must therefore be in the same order as "categories"; a warning
    is logged if the two lengths differ.

    All text files are read as UTF-8.
    """

    def __init__(
        self,
        dataset_path: str | Path,
        annotation_file: str = "annotations.json",
        images_dir: str = "images",
        classes_file: str | None = "classes.txt",
        dataset_id: str | None = None,
    ) -> None:
        # These must be set before the base __init__, which calls _initialize_format_specific.
        self.annotation_file = annotation_file
        self.images_dir = images_dir
        self.classes_file = classes_file

        # Initialize base class
        super().__init__(dataset_path, dataset_id)

    def _initialize_format_specific(self) -> None:
        """Resolve COCO paths, check they exist and load the annotations."""
        self.images_path = self.dataset_path / self.images_dir
        self.annotation_path = self.dataset_path / self.annotation_file
        self.classes_path = self.dataset_path / self.classes_file if self.classes_file else None

        if not self.annotation_path.exists():
            raise FileNotFoundError(f"Annotation file not found: {self.annotation_path}")
        if not self.images_path.exists():
            raise FileNotFoundError(f"Images directory not found: {self.images_path}")

        self._load_annotations()

    @property
    def index2label(self) -> dict[int, str]:
        """Mapping from class index to class name."""
        return self._index2label

    def _create_dataset_implementation(self) -> ObjectDetectionDataset:
        """Create COCO dataset implementation."""
        return _COCODataset(self)

    def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
        """
        Validate COCO format specific files and structure.

        Problems are reported as issue strings rather than raised, including an
        unreadable annotation file, invalid JSON, a non-object top-level value and
        required sections that are missing or not lists.
        """
        issues: list[str] = []
        stats: dict[str, Any] = {}

        annotation_path = self.dataset_path / self.annotation_file
        if not annotation_path.exists():
            issues.append(f"Missing {self.annotation_file} file")
            return issues, stats

        try:
            with open(annotation_path, encoding="utf-8") as f:
                coco_data = json.load(f)
        except json.JSONDecodeError as e:
            issues.append(f"Invalid JSON in {self.annotation_file}: {e}")
            return issues, stats
        except (OSError, UnicodeDecodeError) as e:
            issues.append(f"Error reading {self.annotation_file}: {e}")
            return issues, stats

        if not isinstance(coco_data, dict):
            issues.append(f"Top-level JSON value in {self.annotation_file} must be an object")
            return issues, stats

        # Check required keys
        for key in ("images", "annotations", "categories"):
            if key not in coco_data:
                issues.append(f"Missing required key '{key}' in {self.annotation_file}")
            elif not isinstance(coco_data[key], list):
                issues.append(f"Key '{key}' in {self.annotation_file} must be a list")
            else:
                stats[f"num_{key}"] = len(coco_data[key])

        # Check optional classes.txt
        if self.classes_file:
            classes_path = self.dataset_path / self.classes_file
            if classes_path.exists():
                try:
                    with open(classes_path, encoding="utf-8") as f:
                        class_lines = [line.strip() for line in f if line.strip()]
                    stats["num_class_names"] = len(class_lines)
                except (OSError, UnicodeDecodeError) as e:
                    issues.append(f"Error reading {self.classes_file}: {e}")

        return issues, stats

    def _load_annotations(self) -> None:
        """Load and parse COCO annotations and build the lookup tables used by the dataset."""
        with open(self.annotation_path, encoding="utf-8") as f:
            self.coco_data = json.load(f)

        # Build mappings
        self.image_id_to_info = {img["id"]: img for img in self.coco_data["images"]}
        # Labels are positional indices into "categories"; the raw category ids are
        # not assumed to be zero-based or contiguous.
        self.category_id_to_idx = {cat["id"]: idx for idx, cat in enumerate(self.coco_data["categories"])}

        # Group annotations by image
        self.image_id_to_annotations: dict[int, list[dict[str, Any]]] = {}
        for ann in self.coco_data["annotations"]:
            self.image_id_to_annotations.setdefault(ann["image_id"], []).append(ann)

        # Load class names
        if self.classes_path and self.classes_path.exists():
            with open(self.classes_path, encoding="utf-8") as f:
                class_names = [line.strip() for line in f if line.strip()]
            num_categories = len(self.coco_data["categories"])
            if len(class_names) != num_categories:
                _logger.warning(
                    "%s lists %d class names but %s defines %d categories; index2label may not match labels",
                    self.classes_path,
                    len(class_names),
                    self.annotation_path,
                    num_categories,
                )
        else:
            class_names = [cat["name"] for cat in self.coco_data["categories"]]

        self._index2label = dict(enumerate(class_names))
|
202
|
+
|
203
|
+
|
204
|
+
class _COCODataset:
    """
    Internal COCO dataset implementation returned by ``COCODatasetReader.get_dataset``.

    Datums follow the order of the ``"images"`` list in the annotation file. Each
    datum is ``(image, target, metadata)``, where the image is a CHW uint8-style
    array converted to RGB.
    """

    # Image fields already stored explicitly in the datum metadata.
    _SKIPPED_IMAGE_FIELDS = frozenset({"id", "file_name", "width", "height"})
    # Metadata keys set explicitly in __getitem__; image fields with these names are
    # prefixed with "image_" so they neither collide nor overwrite.
    _RESERVED_METADATA_KEYS = frozenset({"coco_image_id", "annotations", "num_annotations"})
    # Annotation fields already represented in the per-annotation metadata dict.
    _SKIPPED_ANNOTATION_FIELDS = frozenset({"id", "image_id", "category_id", "bbox", "area", "iscrowd"})

    def __init__(self, reader: COCODatasetReader) -> None:
        self.reader = reader
        self.image_ids = list(reader.image_id_to_info.keys())

    @property
    def metadata(self) -> DatasetMetadata:
        return DatasetMetadata(
            id=self.reader._dataset_id,
            index2label=self.reader.index2label,
        )

    def __len__(self) -> int:
        return len(self.image_ids)

    @classmethod
    def _extra_image_fields(cls, image_info: dict[str, Any]) -> dict[str, Any]:
        """Return the optional COCO image fields to copy into the datum metadata."""
        extras: dict[str, Any] = {}
        for key, value in image_info.items():
            if key in cls._SKIPPED_IMAGE_FIELDS:
                continue
            # Passing a reserved name through ** would raise a duplicate-keyword TypeError.
            extras[f"image_{key}" if key in cls._RESERVED_METADATA_KEYS else key] = value
        return extras

    def __getitem__(self, index: int) -> ObjectDetectionDatum:
        image_id = self.image_ids[index]
        image_info = self.reader.image_id_to_info[image_id]

        # Load image. The context manager closes the file handle right away;
        # Image.open is lazy and would otherwise keep the file open until collected.
        image_path = self.reader.images_path / image_info["file_name"]
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        image = np.transpose(image, (2, 0, 1))  # Convert HWC to CHW format

        # Get annotations for this image
        annotations = self.reader.image_id_to_annotations.get(image_id, [])

        if annotations:
            boxes = []
            labels = []
            annotation_metadata = []

            for ann in annotations:
                # Convert COCO bbox (x, y, w, h) to (x1, y1, x2, y2)
                x, y, w, h = ann["bbox"]
                boxes.append([x, y, x + w, y + h])

                # Map category_id to class index
                labels.append(self.reader.category_id_to_idx[ann["category_id"]])

                # Collect annotation metadata
                ann_meta = {
                    "annotation_id": ann["id"],
                    "category_id": ann["category_id"],
                    "area": ann.get("area", 0),
                    "iscrowd": ann.get("iscrowd", 0),
                }
                # Add any additional fields from annotation
                for key, value in ann.items():
                    if key not in self._SKIPPED_ANNOTATION_FIELDS:
                        ann_meta[f"ann_{key}"] = value
                annotation_metadata.append(ann_meta)

            boxes = np.array(boxes, dtype=np.float32)
            labels = np.array(labels, dtype=np.int64)
            scores = np.ones(len(labels), dtype=np.float32)  # Ground truth scores
        else:
            # Empty annotations
            boxes = np.empty((0, 4), dtype=np.float32)
            labels = np.empty(0, dtype=np.int64)
            scores = np.empty(0, dtype=np.float32)
            annotation_metadata = []

        target = _ObjectDetectionTarget(boxes, labels, scores)

        # Create comprehensive datum metadata
        datum_metadata = DatumMetadata(
            id=f"{self.reader._dataset_id}_{image_id}",
            # Image-level metadata
            coco_image_id=image_id,
            file_name=image_info["file_name"],
            width=image_info["width"],
            height=image_info["height"],
            # Optional COCO image fields
            **self._extra_image_fields(image_info),
            # Annotation metadata
            annotations=annotation_metadata,
            num_annotations=len(annotations),
        )

        return image, target, datum_metadata
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
from maite_datasets._reader._base import BaseDatasetReader
|
7
|
+
from maite_datasets._reader._yolo import YOLODatasetReader
|
8
|
+
from maite_datasets._reader._coco import COCODatasetReader
|
9
|
+
|
10
|
+
_logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
def create_dataset_reader(dataset_path: str | Path, format_hint: str | None = None) -> BaseDatasetReader:
    """
    Factory function to create appropriate dataset reader based on directory structure.

    Parameters
    ----------
    dataset_path : str or Path
        Root directory containing dataset files
    format_hint : str or None, default None
        Format hint ("coco" or "yolo", case-insensitive). If None or empty,
        auto-detects based on file structure

    Returns
    -------
    BaseDatasetReader
        Appropriate reader instance for the detected format

    Raises
    ------
    ValueError
        If format cannot be determined or is unsupported
    """
    dataset_path = Path(dataset_path)

    if format_hint:
        format_hint = format_hint.lower()
        if format_hint == "coco":
            return COCODatasetReader(dataset_path)
        elif format_hint == "yolo":
            return YOLODatasetReader(dataset_path)
        else:
            raise ValueError(f"Unsupported format hint: {format_hint}")

    # Auto-detect format. Check the entry types, not mere existence, so a stray
    # file named "labels" is not mistaken for a YOLO labels directory.
    has_annotations_json = (dataset_path / "annotations.json").is_file()
    has_labels_dir = (dataset_path / "labels").is_dir()

    if has_annotations_json and not has_labels_dir:
        _logger.info("Detected COCO format for %s", dataset_path)
        return COCODatasetReader(dataset_path)
    elif has_labels_dir and not has_annotations_json:
        _logger.info("Detected YOLO format for %s", dataset_path)
        return YOLODatasetReader(dataset_path)
    elif has_annotations_json and has_labels_dir:
        raise ValueError(
            f"Ambiguous format in {dataset_path}: both annotations.json and labels/ exist. "
            "Use format_hint parameter to specify format."
        )
    else:
        raise ValueError(
            f"Cannot detect dataset format in {dataset_path}. "
            "Expected either annotations.json (COCO) or labels/ directory (YOLO)."
        )
|
@@ -0,0 +1,312 @@
|
|
1
|
+
"""Dataset reader for YOLO detection format."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
__all__ = []
|
6
|
+
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from PIL import Image
|
12
|
+
|
13
|
+
from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
|
14
|
+
from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
|
15
|
+
|
16
|
+
|
17
|
+
class YOLODatasetReader(BaseDatasetReader):
|
18
|
+
"""
|
19
|
+
YOLO format dataset reader conforming to MAITE protocols.
|
20
|
+
|
21
|
+
Reads YOLO format object detection datasets from disk and provides
|
22
|
+
MAITE-compatible interface.
|
23
|
+
|
24
|
+
Directory Structure Requirements
|
25
|
+
--------------------------------
|
26
|
+
```
|
27
|
+
dataset_root/
|
28
|
+
├── images/
|
29
|
+
│ ├── image1.jpg
|
30
|
+
│ ├── image2.jpg
|
31
|
+
│ └── ...
|
32
|
+
├── labels/
|
33
|
+
│ ├── image1.txt # YOLO format annotations
|
34
|
+
│ ├── image2.txt
|
35
|
+
│ └── ...
|
36
|
+
├── classes.txt # Required: one class name per line
|
37
|
+
└── data.yaml # Optional: dataset metadata
|
38
|
+
```
|
39
|
+
|
40
|
+
YOLO Format Specifications
|
41
|
+
--------------------------
|
42
|
+
Label file format (one line per object):
|
43
|
+
```
|
44
|
+
class_id center_x center_y width height
|
45
|
+
0 0.5 0.3 0.2 0.4
|
46
|
+
1 0.7 0.8 0.1 0.2
|
47
|
+
```
|
48
|
+
All YOLO coordinates are normalized to [0, 1] relative to image dimensions.
|
49
|
+
|
50
|
+
classes.txt format (required, one class per line, ordered by index):
|
51
|
+
```
|
52
|
+
person
|
53
|
+
bicycle
|
54
|
+
car
|
55
|
+
motorcycle
|
56
|
+
```
|
57
|
+
|
58
|
+
Parameters
|
59
|
+
----------
|
60
|
+
dataset_path : str or Path
|
61
|
+
Root directory containing YOLO dataset files
|
62
|
+
images_dir : str, default "images"
|
63
|
+
Name of directory containing images
|
64
|
+
labels_dir : str, default "labels"
|
65
|
+
Name of directory containing YOLO label files
|
66
|
+
classes_file : str, default "classes.txt"
|
67
|
+
File containing class names (one per line)
|
68
|
+
dataset_id : str or None, default None
|
69
|
+
Dataset identifier. If None, uses dataset_path name
|
70
|
+
image_extensions : list[str], default [".jpg", ".jpeg", ".png", ".bmp"]
|
71
|
+
Supported image file extensions
|
72
|
+
|
73
|
+
Notes
|
74
|
+
-----
|
75
|
+
YOLO label files should contain one line per object:
|
76
|
+
`class_id center_x center_y width height`
|
77
|
+
|
78
|
+
All coordinates should be normalized to [0, 1] relative to image dimensions.
|
79
|
+
Coordinates are converted to absolute pixel values and MAITE format (x1, y1, x2, y2).
|
80
|
+
"""
|
81
|
+
|
82
|
+
def __init__(
|
83
|
+
self,
|
84
|
+
dataset_path: str | Path,
|
85
|
+
images_dir: str = "images",
|
86
|
+
labels_dir: str = "labels",
|
87
|
+
classes_file: str = "classes.txt",
|
88
|
+
dataset_id: str | None = None,
|
89
|
+
image_extensions: list[str] | None = None,
|
90
|
+
) -> None:
|
91
|
+
self.images_dir = images_dir
|
92
|
+
self.labels_dir = labels_dir
|
93
|
+
self.classes_file = classes_file
|
94
|
+
|
95
|
+
if image_extensions is None:
|
96
|
+
image_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
|
97
|
+
self.image_extensions = [ext.lower() for ext in image_extensions]
|
98
|
+
|
99
|
+
# Initialize base class
|
100
|
+
super().__init__(dataset_path, dataset_id)
|
101
|
+
|
102
|
+
def _initialize_format_specific(self) -> None:
|
103
|
+
"""Initialize YOLO-specific components."""
|
104
|
+
self.images_path = self.dataset_path / self.images_dir
|
105
|
+
self.labels_path = self.dataset_path / self.labels_dir
|
106
|
+
self.classes_path = self.dataset_path / self.classes_file
|
107
|
+
|
108
|
+
if not self.images_path.exists():
|
109
|
+
raise FileNotFoundError(f"Images directory not found: {self.images_path}")
|
110
|
+
if not self.labels_path.exists():
|
111
|
+
raise FileNotFoundError(f"Labels directory not found: {self.labels_path}")
|
112
|
+
if not self.classes_path.exists():
|
113
|
+
raise FileNotFoundError(f"Classes file not found: {self.classes_path}")
|
114
|
+
|
115
|
+
self._load_class_names()
|
116
|
+
self._find_image_files()
|
117
|
+
|
118
|
+
@property
def index2label(self) -> dict[int, str]:
    """Class index -> class name, as loaded by ``_load_class_names``.

    The reader's internal dict is returned directly (not a copy).
    """
    mapping = self._index2label
    return mapping
|
122
|
+
|
123
|
+
def _create_dataset_implementation(self) -> ObjectDetectionDataset:
    """Wrap this reader in the internal ``_YOLODataset`` view and return it."""
    dataset = _YOLODataset(self)
    return dataset
|
126
|
+
|
127
|
+
def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
|
128
|
+
"""Validate YOLO format specific files and structure."""
|
129
|
+
issues = []
|
130
|
+
stats = {}
|
131
|
+
|
132
|
+
# Check labels directory
|
133
|
+
labels_path = self.dataset_path / self.labels_dir
|
134
|
+
if not labels_path.exists():
|
135
|
+
issues.append(f"Missing {self.labels_dir}/ directory")
|
136
|
+
else:
|
137
|
+
label_files = list(labels_path.glob("*.txt"))
|
138
|
+
stats["num_label_files"] = len(label_files)
|
139
|
+
if len(label_files) == 0:
|
140
|
+
issues.append(f"No label files found in {self.labels_dir}/ directory")
|
141
|
+
else:
|
142
|
+
# Validate label file format (sample check)
|
143
|
+
label_issues = self._validate_yolo_label_format(labels_path)
|
144
|
+
issues.extend(label_issues)
|
145
|
+
|
146
|
+
# Check required classes.txt
|
147
|
+
classes_path = self.dataset_path / self.classes_file
|
148
|
+
if not classes_path.exists():
|
149
|
+
issues.append(f"Missing required {self.classes_file} file")
|
150
|
+
else:
|
151
|
+
try:
|
152
|
+
with open(classes_path) as f:
|
153
|
+
class_lines = [line.strip() for line in f if line.strip()]
|
154
|
+
stats["num_classes"] = len(class_lines)
|
155
|
+
if len(class_lines) == 0:
|
156
|
+
issues.append(f"{self.classes_file} is empty")
|
157
|
+
except Exception as e:
|
158
|
+
issues.append(f"Error reading {self.classes_file}: {e}")
|
159
|
+
|
160
|
+
return issues, stats
|
161
|
+
|
162
|
+
def _validate_yolo_label_format(self, labels_path: Path) -> list[str]:
|
163
|
+
"""Validate YOLO label file format (sample check)."""
|
164
|
+
issues = []
|
165
|
+
label_files = list(labels_path.glob("*.txt"))
|
166
|
+
|
167
|
+
if not label_files:
|
168
|
+
return issues
|
169
|
+
|
170
|
+
sample_label = label_files[0]
|
171
|
+
try:
|
172
|
+
with open(sample_label) as f:
|
173
|
+
for line_num, line in enumerate(f, 1):
|
174
|
+
if not line.strip():
|
175
|
+
continue
|
176
|
+
|
177
|
+
parts = line.strip().split()
|
178
|
+
if len(parts) != 5:
|
179
|
+
issues.append(
|
180
|
+
f"Invalid YOLO format in {sample_label.name} line {line_num}: "
|
181
|
+
f"expected 5 values, got {len(parts)}"
|
182
|
+
)
|
183
|
+
break
|
184
|
+
|
185
|
+
try:
|
186
|
+
coords = [float(x) for x in parts[1:]]
|
187
|
+
if not all(0 <= coord <= 1 for coord in coords):
|
188
|
+
issues.append(f"Coordinates out of range [0,1] in {sample_label.name} line {line_num}")
|
189
|
+
break
|
190
|
+
except ValueError:
|
191
|
+
issues.append(f"Invalid numeric values in {sample_label.name} line {line_num}")
|
192
|
+
break
|
193
|
+
except Exception as e:
|
194
|
+
issues.append(f"Error validating label file {sample_label.name}: {e}")
|
195
|
+
|
196
|
+
return issues
|
197
|
+
|
198
|
+
def _load_class_names(self) -> None:
|
199
|
+
"""Load class names from classes file."""
|
200
|
+
with open(self.classes_path) as f:
|
201
|
+
class_names = [line.strip() for line in f if line.strip()]
|
202
|
+
self._index2label = {idx: name for idx, name in enumerate(class_names)}
|
203
|
+
|
204
|
+
def _find_image_files(self) -> None:
|
205
|
+
"""Find all valid image files."""
|
206
|
+
self.image_files = []
|
207
|
+
for ext in self.image_extensions:
|
208
|
+
self.image_files.extend(self.images_path.glob(f"*{ext}"))
|
209
|
+
self.image_files.sort()
|
210
|
+
|
211
|
+
if not self.image_files:
|
212
|
+
raise ValueError(f"No image files found in {self.images_path}")
|
213
|
+
|
214
|
+
|
215
|
+
class _YOLODataset:
    """Internal YOLO dataset implementation.

    Index-based view over a ``YOLODatasetReader``: item ``i`` is the reader's
    ``i``-th image file plus the same-stem ``.txt`` label file, read from disk
    on every access.
    """

    def __init__(self, reader: YOLODatasetReader) -> None:
        self.reader = reader

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata: the reader's dataset id and index2label."""
        return DatasetMetadata(
            id=self.reader._dataset_id,
            index2label=self.reader.index2label,
        )

    def __len__(self) -> int:
        return len(self.reader.image_files)

    def __getitem__(self, index: int) -> ObjectDetectionDatum:
        """Load image *index* and its YOLO annotations.

        Returns ``(image, target, datum_metadata)``:

        - ``image`` is the RGB image transposed to CHW order.
        - ``target`` holds absolute-pixel ``(x1, y1, x2, y2)`` boxes, class ids
          and all-ones scores.
        - ``datum_metadata`` records file and per-annotation details.

        Label lines that are blank or do not have exactly 5 values are skipped.
        A missing label file yields empty annotations.
        """
        image_path = self.reader.image_files[index]

        # Context manager closes the file handle PIL keeps open after
        # Image.open(); otherwise one descriptor leaks per item until GC.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        img_height, img_width = image.shape[:2]
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW

        label_path = self.reader.labels_path / f"{image_path.stem}.txt"
        label_exists = label_path.exists()

        boxes = []
        labels = []
        annotation_metadata = []
        if label_exists:
            with open(label_path) as f:
                for line_num, line in enumerate(f, 1):
                    parts = line.split()
                    if len(parts) != 5:
                        # Blank or malformed lines are ignored.
                        continue

                    class_id = int(parts[0])
                    center_x, center_y, width, height = map(float, parts[1:])

                    # Convert normalized YOLO (cx, cy, w, h) to absolute pixel corners.
                    x1 = (center_x - width / 2) * img_width
                    y1 = (center_y - height / 2) * img_height
                    x2 = (center_x + width / 2) * img_width
                    y2 = (center_y + height / 2) * img_height

                    boxes.append([x1, y1, x2, y2])
                    labels.append(class_id)
                    # Keep the original YOLO values alongside the converted box.
                    annotation_metadata.append(
                        {
                            "line_number": line_num,
                            "class_id": class_id,
                            "yolo_center_x": center_x,
                            "yolo_center_y": center_y,
                            "yolo_width": width,
                            "yolo_height": height,
                            "absolute_bbox": [x1, y1, x2, y2],
                        }
                    )

        # reshape(-1, 4) gives shape (0, 4) when there are no boxes, so the
        # empty and non-empty cases share one code path.
        target = _ObjectDetectionTarget(
            np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
            np.asarray(labels, dtype=np.int64),
            np.ones(len(labels), dtype=np.float32),  # ground truth: score 1.0
        )

        datum_metadata = DatumMetadata(
            id=f"{self.reader._dataset_id}_{image_path.stem}",
            # Image-level metadata
            file_name=image_path.name,
            file_path=str(image_path),
            width=img_width,
            height=img_height,
            # Label file metadata
            label_file=label_path.name if label_exists else None,
            label_file_exists=label_exists,
            # Annotation metadata
            annotations=annotation_metadata,
            num_annotations=len(annotation_metadata),
        )

        return image, target, datum_metadata
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from collections.abc import Sequence, Sized
|
7
|
+
from typing import Any, Literal
|
8
|
+
|
9
|
+
from maite_datasets._protocols import Array, ObjectDetectionTarget
|
10
|
+
|
11
|
+
|
12
|
+
class ValidationMessages:
    """Issue messages reported by the dataset validators in this module."""

    DATASET_SIZED = "Dataset must be sized."
    DATASET_INDEXABLE = "Dataset must be indexable."
    DATASET_NONEMPTY = "Dataset must be non-empty."
    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
    DATUM_TYPE = "Dataset datum must be a tuple."
    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must have 'boxes', 'labels' and 'scores'."
    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
    # Boxes are corner coordinates (x1, y1, x2, y2), i.e. "xyxy".
    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xyxy format."
    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
|
32
|
+
|
33
|
+
|
34
|
+
def _validate_dataset_type(dataset: Any) -> list[str]:
|
35
|
+
issues = []
|
36
|
+
is_sized = isinstance(dataset, Sized)
|
37
|
+
is_indexable = hasattr(dataset, "__getitem__")
|
38
|
+
if not is_sized:
|
39
|
+
issues.append(ValidationMessages.DATASET_SIZED)
|
40
|
+
if not is_indexable:
|
41
|
+
issues.append(ValidationMessages.DATASET_INDEXABLE)
|
42
|
+
if is_sized and len(dataset) == 0:
|
43
|
+
issues.append(ValidationMessages.DATASET_NONEMPTY)
|
44
|
+
return issues
|
45
|
+
|
46
|
+
|
47
|
+
def _validate_dataset_metadata(dataset: Any) -> list[str]:
|
48
|
+
issues = []
|
49
|
+
if not hasattr(dataset, "metadata"):
|
50
|
+
issues.append(ValidationMessages.DATASET_METADATA)
|
51
|
+
metadata = getattr(dataset, "metadata", None)
|
52
|
+
if not isinstance(metadata, dict):
|
53
|
+
issues.append(ValidationMessages.DATASET_METADATA_TYPE)
|
54
|
+
if not isinstance(metadata, dict) or "id" not in metadata:
|
55
|
+
issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
|
56
|
+
return issues
|
57
|
+
|
58
|
+
|
59
|
+
def _validate_datum_type(datum: Any) -> list[str]:
|
60
|
+
issues = []
|
61
|
+
if not isinstance(datum, tuple):
|
62
|
+
issues.append(ValidationMessages.DATUM_TYPE)
|
63
|
+
if datum is None or isinstance(datum, Sized) and len(datum) != 3:
|
64
|
+
issues.append(ValidationMessages.DATUM_FORMAT)
|
65
|
+
return issues
|
66
|
+
|
67
|
+
|
68
|
+
def _validate_datum_image(image: Any) -> list[str]:
    """Check that *image* is a 3-D ``Array`` laid out channels-first (CHW).

    The CHW test is a heuristic: the leading dimension must not exceed either
    of the other two. A non-``Array`` reports both the rank and layout
    messages. A wrong-rank ``Array`` reports only the rank message.
    """
    if not isinstance(image, Array):
        return [ValidationMessages.DATUM_IMAGE_TYPE, ValidationMessages.DATUM_IMAGE_FORMAT]

    shape = image.shape
    if len(shape) != 3:
        return [ValidationMessages.DATUM_IMAGE_TYPE]

    channels, height, width = shape
    if channels > height or channels > width:
        return [ValidationMessages.DATUM_IMAGE_FORMAT]
    return []
|
79
|
+
|
80
|
+
|
81
|
+
def _validate_datum_target_ic(target: Any) -> list[str]:
    """Validate an image-classification target.

    A valid target is a one-dimensional ``Array`` whose values sum to 1 within
    1e-6, i.e. a one-hot vector or pseudo-probabilities.

    Returns the applicable ``ValidationMessages``. Targets that cannot be summed
    (``None``, or objects NumPy cannot convert to floats) are reported as
    format issues instead of raising. Multi-dimensional arrays are summed over
    all elements.
    """
    issues = []
    if not isinstance(target, Array) or len(target.shape) != 1:
        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)

    total = None
    if target is not None:
        try:
            # np.asarray(...).sum() always yields a scalar. The builtin sum()
            # returned an array for 2-D inputs, which made the comparison
            # below raise "truth value of an array is ambiguous", and it raised
            # TypeError on non-numeric targets. Using np.asarray matches how
            # _validate_datum_target_od inspects target components.
            total = float(np.asarray(target, dtype=np.float64).sum())
        except (TypeError, ValueError):
            total = None

    # np.isclose also rejects NaN totals.
    if total is None or not np.isclose(total, 1.0, rtol=0.0, atol=1e-6):
        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
    return issues
|
88
|
+
|
89
|
+
|
90
|
+
def _validate_datum_target_od(target: Any) -> list[str]:
    """Validate an object-detection target's type and array ranks.

    Expects an ``ObjectDetectionTarget`` with:

    - ``labels`` of shape (N,);
    - ``boxes`` of shape (N, 4);
    - ``scores`` of shape (N,) or (N, M).

    Components are converted with ``np.asarray`` before their shapes are
    inspected. A non-``ObjectDetectionTarget`` reports every message.
    """
    if not isinstance(target, ObjectDetectionTarget):
        return [
            ValidationMessages.DATUM_TARGET_OD_TYPE,
            ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE,
            ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE,
            ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE,
        ]

    labels_shape = np.asarray(target.labels).shape
    boxes_shape = np.asarray(target.boxes).shape
    scores_shape = np.asarray(target.scores).shape

    problems = []
    if len(labels_shape) != 1:
        problems.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
    if len(boxes_shape) != 2 or boxes_shape[1] != 4:
        problems.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
    if len(scores_shape) not in (1, 2):
        problems.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
    return problems
|
106
|
+
|
107
|
+
|
108
|
+
def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
    """Infer the task from a target.

    Returns ``"ic"`` for an ``Array``, else ``"od"`` for an
    ``ObjectDetectionTarget``, else ``"auto"``, meaning undetermined. The
    ``Array`` check takes precedence.
    """
    return "ic" if isinstance(target, Array) else "od" if isinstance(target, ObjectDetectionTarget) else "auto"
|
114
|
+
|
115
|
+
|
116
|
+
def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
    """Dispatch *target* to the IC or OD validator.

    With ``target_type="auto"`` the type is inferred via ``_detect_target_type``.
    If it cannot be resolved to ``"ic"`` or ``"od"``, a single
    ``DATUM_TARGET_TYPE`` message is returned.
    """
    resolved = _detect_target_type(target) if target_type == "auto" else target_type
    validators = {"ic": _validate_datum_target_ic, "od": _validate_datum_target_od}
    validator = validators.get(resolved)
    if validator is None:
        return [ValidationMessages.DATUM_TARGET_TYPE]
    return list(validator(target))
|
126
|
+
|
127
|
+
|
128
|
+
def _validate_datum_metadata(metadata: Any) -> list[str]:
|
129
|
+
issues = []
|
130
|
+
if metadata is None or not isinstance(metadata, dict):
|
131
|
+
issues.append(ValidationMessages.DATUM_METADATA_TYPE)
|
132
|
+
if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
|
133
|
+
issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
|
134
|
+
return issues
|
135
|
+
|
136
|
+
|
137
|
+
def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
    """
    Validate a dataset for compliance with MAITE protocol.

    Dataset-level checks (sized, indexable, non-empty, ``metadata`` with an
    ``'id'``) apply to the dataset object. Datum-level checks (3-tuple layout,
    CHW image, target, datum metadata) are applied to ``dataset[0]`` only.
    That datum is fetched only if the dataset-level type checks pass.

    Parameters
    ----------
    dataset: Any
        Dataset to validate.
    dataset_type: "ic", "od", or "auto", default "auto"
        Dataset type, if known. With "auto" the type is inferred from the
        first datum's target.

    Raises
    ------
    ValueError
        Raises exception if dataset is invalid with a list of validation issues.
    """
    issues = _validate_dataset_type(dataset)
    # Only index into the dataset once it is known to be sized, indexable and non-empty.
    datum = dataset[0] if not issues else None  # type: ignore
    issues += _validate_dataset_metadata(dataset)
    issues += _validate_datum_type(datum)

    if isinstance(datum, Sequence):
        count = len(datum)
        image, target, metadata = (datum[i] if i < count else None for i in range(3))
    else:
        image = target = metadata = None

    issues += _validate_datum_image(image)
    issues += _validate_datum_target(target, dataset_type)
    issues += _validate_datum_metadata(metadata)

    if issues:
        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: maite-datasets
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.4
|
4
4
|
Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
|
5
5
|
Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
|
6
6
|
License-Expression: MIT
|
@@ -1,13 +1,20 @@
|
|
1
|
-
maite_datasets/__init__.py,sha256=
|
1
|
+
maite_datasets/__init__.py,sha256=53LW5bHMAr4uD6w2bvrPxgtROUIzaE-3LR6TR0dDucs,746
|
2
2
|
maite_datasets/_base.py,sha256=BiWB_xvL4AtV0jxVjzpcZHuRTb52dTD0CQtu08DzoXA,8195
|
3
3
|
maite_datasets/_builder.py,sha256=URhRCedvuqsy88N4lzQrwI-uL1kS1_kavP9fS402sPw,10036
|
4
|
+
maite_datasets/_collate.py,sha256=-XuKeeMmOnSB0RgQbz8BjsoqQar9Tsf_qALZxijQ498,4063
|
4
5
|
maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
|
5
|
-
maite_datasets/_protocols.py,sha256=
|
6
|
+
maite_datasets/_protocols.py,sha256=aWrnUM1stZ9VInkBEynod_OdYq2ORSpew7yoF-Zeuig,5247
|
6
7
|
maite_datasets/_types.py,sha256=S5DMyiUrkUjV9uM0ysKqxVoi7z5P7B3EPiLI4Fyq9Jc,1147
|
8
|
+
maite_datasets/_validate.py,sha256=sP-5lYXkmkiTadJcy_LtEMiZ0m82xR0yELoxWORrZDQ,6904
|
7
9
|
maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
10
|
maite_datasets/_mixin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
11
|
maite_datasets/_mixin/_numpy.py,sha256=GEuRyeprH-STh-_zktAp0Tg6NNyMdh1ThyhjW558NOo,860
|
10
12
|
maite_datasets/_mixin/_torch.py,sha256=pkN2vMNsDk_h5wnD5899zIHsPtEADbGfmRyI5CdGonI,827
|
13
|
+
maite_datasets/_reader/__init__.py,sha256=VzrVOsmztPJV83um8tY5qdqU-HEPP15RlLClGbxTFlQ,164
|
14
|
+
maite_datasets/_reader/_base.py,sha256=3_425HLcvfEU9bqQjy9S9gvqXlkPDR471IVXFxBozl0,4289
|
15
|
+
maite_datasets/_reader/_coco.py,sha256=YyDrgdXZog_EHViWary5k8bkQd-jNaPm-G1wiN_V5ks,9960
|
16
|
+
maite_datasets/_reader/_factory.py,sha256=cI3Cw1yWj4hK2gn6N5bugXzGMcNwcCEkJ4AoynwOZvI,2222
|
17
|
+
maite_datasets/_reader/_yolo.py,sha256=abWAXrFFGE00NlIMUb_lAoiXFykGYOAGKGHekhG30Q8,11462
|
11
18
|
maite_datasets/image_classification/__init__.py,sha256=pcZojkdsiMoLgY4mKjoQY6WyEwiGYHxNrAGpnvn3zsY,308
|
12
19
|
maite_datasets/image_classification/_cifar10.py,sha256=w7BPGZzUV1gXFoYRgxa6VOqKn1EgQi3x1rrA4nEUbeI,8470
|
13
20
|
maite_datasets/image_classification/_mnist.py,sha256=6xDWY4qbY1hlcUZKvVZeQMvYbF0vLtaVzOuQUKJkcJU,8248
|
@@ -18,7 +25,7 @@ maite_datasets/object_detection/_milco.py,sha256=KEU4JFvCxfyMAb4RFMnxTMk_MggdEAV
|
|
18
25
|
maite_datasets/object_detection/_seadrone.py,sha256=w_pSojLzgwdKrUSxaz8r7dPJVKGND6JSYl0S_BKOLH0,271282
|
19
26
|
maite_datasets/object_detection/_voc.py,sha256=VuokKaOzI1wSfgG5DC7ufMbRDlG-b6Se3hg4eQzNQbE,19731
|
20
27
|
maite_datasets/object_detection/_voc_torch.py,sha256=bjeawnNit7Llcf_cZY_9lcJYoUoAU-Wen6MMT-7QX3k,2917
|
21
|
-
maite_datasets-0.0.
|
22
|
-
maite_datasets-0.0.
|
23
|
-
maite_datasets-0.0.
|
24
|
-
maite_datasets-0.0.
|
28
|
+
maite_datasets-0.0.4.dist-info/METADATA,sha256=8-83ACnQAjf9LJgZY25GvIPGL5o5Wi0RA-SEog7jcvU,3747
|
29
|
+
maite_datasets-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
30
|
+
maite_datasets-0.0.4.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
|
31
|
+
maite_datasets-0.0.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|