hafnia-0.3.0-py3-none-any.whl → hafnia-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +3 -1
- cli/config.py +43 -3
- cli/keychain.py +88 -0
- cli/profile_cmds.py +5 -2
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_helpers.py +9 -2
- hafnia/dataset/dataset_names.py +2 -1
- hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
- hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
- hafnia/dataset/dataset_upload_helper.py +60 -4
- hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
- hafnia/dataset/hafnia_dataset.py +176 -50
- hafnia/dataset/operations/dataset_stats.py +2 -3
- hafnia/dataset/operations/dataset_transformations.py +19 -15
- hafnia/dataset/operations/table_transformations.py +4 -3
- hafnia/dataset/primitives/bbox.py +25 -12
- hafnia/dataset/primitives/bitmask.py +26 -14
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +16 -9
- hafnia/dataset/primitives/segmentation.py +10 -7
- hafnia/experiment/hafnia_logger.py +0 -9
- hafnia/platform/dataset_recipe.py +7 -2
- hafnia/platform/datasets.py +3 -3
- hafnia/platform/download.py +23 -18
- hafnia/utils.py +17 -0
- hafnia/visualizations/image_visualizations.py +1 -1
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/METADATA +8 -6
- hafnia-0.4.0.dist-info/RECORD +56 -0
- hafnia-0.3.0.dist-info/RECORD +0 -53
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/format_conversions/torchvision_datasets.py
ADDED

@@ -0,0 +1,281 @@
+import inspect
+import os
+import shutil
+import tempfile
+import textwrap
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple
+
+from rich.progress import track
+from torchvision import datasets as tv_datasets
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.utils import download_and_extract_archive, extract_archive
+
+from hafnia import utils
+from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name
+from hafnia.dataset.dataset_names import SplitName
+from hafnia.dataset.format_conversions.image_classification_from_directory import (
+    import_image_classification_directory_tree,
+)
+from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+from hafnia.dataset.primitives import Classification
+
+
+def torchvision_to_hafnia_converters() -> Dict[str, Callable]:
+    return {
+        "mnist": mnist_as_hafnia_dataset,
+        "cifar10": cifar10_as_hafnia_dataset,
+        "cifar100": cifar100_as_hafnia_dataset,
+        "caltech-101": caltech_101_as_hafnia_dataset,
+        "caltech-256": caltech_256_as_hafnia_dataset,
+    }
+
+
+def mnist_as_hafnia_dataset(force_redownload=False, n_samples: Optional[int] = None) -> HafniaDataset:
+    samples, tasks = torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+        dataset_loader=tv_datasets.MNIST,
+        force_redownload=force_redownload,
+        n_samples=n_samples,
+    )
+
+    dataset_info = DatasetInfo(
+        dataset_name="mnist",
+        version="1.1.0",
+        tasks=tasks,
+        reference_bibtex=textwrap.dedent("""\
+            @article{lecun2010mnist,
+              title={MNIST handwritten digit database},
+              author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
+              journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
+              volume={2},
+              year={2010}
+            }"""),
+        reference_paper_url=None,
+        reference_dataset_page="http://yann.lecun.com/exdb/mnist",
+    )
+    return HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+
+
+def cifar10_as_hafnia_dataset(force_redownload: bool = False, n_samples: Optional[int] = None) -> HafniaDataset:
+    return cifar_as_hafnia_dataset(dataset_name="cifar10", force_redownload=force_redownload, n_samples=n_samples)
+
+
+def cifar100_as_hafnia_dataset(force_redownload: bool = False, n_samples: Optional[int] = None) -> HafniaDataset:
+    return cifar_as_hafnia_dataset(dataset_name="cifar100", force_redownload=force_redownload, n_samples=n_samples)
+
+
+def caltech_101_as_hafnia_dataset(
+    force_redownload: bool = False,
+    n_samples: Optional[int] = None,
+) -> HafniaDataset:
+    dataset_name = "caltech-101"
+    path_image_classification_folder = _download_and_extract_caltech_dataset(
+        dataset_name, force_redownload=force_redownload
+    )
+    hafnia_dataset = import_image_classification_directory_tree(
+        path_image_classification_folder,
+        split=SplitName.TRAIN,
+        n_samples=n_samples,
+    )
+    hafnia_dataset.info.dataset_name = dataset_name
+    hafnia_dataset.info.version = "1.1.0"
+    hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
+        @article{FeiFei2004LearningGV,
+          title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian
+                 Approach Tested on 101 Object Categories},
+          author={Li Fei-Fei and Rob Fergus and Pietro Perona},
+          journal={Computer Vision and Pattern Recognition Workshop},
+          year={2004},
+        }
+    """)
+    hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/mzrjq-6wc02"
+
+    return hafnia_dataset
+
+
+def caltech_256_as_hafnia_dataset(
+    force_redownload: bool = False,
+    n_samples: Optional[int] = None,
+) -> HafniaDataset:
+    dataset_name = "caltech-256"
+
+    path_image_classification_folder = _download_and_extract_caltech_dataset(
+        dataset_name, force_redownload=force_redownload
+    )
+    hafnia_dataset = import_image_classification_directory_tree(
+        path_image_classification_folder,
+        split=SplitName.TRAIN,
+        n_samples=n_samples,
+    )
+    hafnia_dataset.info.dataset_name = dataset_name
+    hafnia_dataset.info.version = "1.1.0"
+    hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
+        @misc{griffin_2023_5sv1j-ytw97,
+          author    = {Griffin, Gregory and
+                       Holub, Alex and
+                       Perona, Pietro},
+          title     = {Caltech-256 Object Category Dataset},
+          month     = aug,
+          year      = 2023,
+          publisher = {California Institute of Technology},
+          version   = {public},
+        }""")
+    hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/nyy15-4j048"
+
+    return hafnia_dataset
+
+
+def cifar_as_hafnia_dataset(
+    dataset_name: str,
+    force_redownload: bool = False,
+    n_samples: Optional[int] = None,
+) -> HafniaDataset:
+    if dataset_name == "cifar10":
+        dataset_loader = tv_datasets.CIFAR10
+    elif dataset_name == "cifar100":
+        dataset_loader = tv_datasets.CIFAR100
+    else:
+        raise ValueError(f"Unknown dataset name: {dataset_name}. Supported: cifar10, cifar100")
+    samples, tasks = torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+        dataset_loader=dataset_loader,
+        force_redownload=force_redownload,
+        n_samples=n_samples,
+    )
+
+    dataset_info = DatasetInfo(
+        dataset_name=dataset_name,
+        version="1.1.0",
+        tasks=tasks,
+        reference_bibtex=textwrap.dedent("""\
+            @@TECHREPORT{Krizhevsky09learningmultiple,
+              author = {Alex Krizhevsky},
+              title = {Learning multiple layers of features from tiny images},
+              institution = {},
+              year = {2009}
+            }"""),
+        reference_paper_url="https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf",
+        reference_dataset_page="https://www.cs.toronto.edu/~kriz/cifar.html",
+    )
+
+    return HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+
+
+def torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+    dataset_loader: VisionDataset,
+    force_redownload: bool = False,
+    n_samples: Optional[int] = None,
+) -> Tuple[List[Sample], List[TaskInfo]]:
+    """
+    Converts a certain group of torchvision-based image classification datasets to a Hafnia Dataset.
+
+    This conversion only works for certain group of image classification VisionDataset by torchvision.
+    Common for these datasets is:
+    1) They provide a 'class_to_idx' mapping,
+    2) A "train" boolean parameter in the init function to separate training and test data - thus no validation split
+       is available for these datasets,
+    3) Datasets are in-memory and not on disk
+    4) Samples consist of a PIL image and a class index.
+
+    """
+    torchvision_dataset_name = dataset_loader.__name__
+
+    # Check if loader has train-parameter using inspect module
+    params = inspect.signature(dataset_loader).parameters
+
+    has_train_param = ("train" in params) and (params["train"].annotation is bool)
+    if not has_train_param:
+        raise ValueError(
+            f"The dataset loader '{dataset_loader.__name__}' does not have a 'train: bool' parameter in the init "
+            "function. This is a sign that the wrong dataset loader is being used. This conversion function only "
+            "works for certain image classification datasets provided by torchvision that are similar to e.g. "
+            "MNIST, CIFAR-10, CIFAR-100"
+        )
+
+    path_torchvision_dataset = utils.get_path_torchvision_downloads() / torchvision_dataset_name
+    path_hafnia_conversions = utils.get_path_hafnia_conversions() / torchvision_dataset_name
+
+    if force_redownload:
+        shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+        shutil.rmtree(path_hafnia_conversions, ignore_errors=True)
+
+    splits = {
+        SplitName.TRAIN: dataset_loader(root=path_torchvision_dataset, train=True, download=True),
+        SplitName.TEST: dataset_loader(root=path_torchvision_dataset, train=False, download=True),
+    }
+
+    samples = []
+    n_samples_per_split = n_samples // len(splits) if n_samples is not None else None
+    for split_name, torchvision_dataset in splits.items():
+        class_name_to_index = torchvision_dataset.class_to_idx
+        class_index_to_name = {v: k for k, v in class_name_to_index.items()}
+        description = f"Convert '{torchvision_dataset_name}' ({split_name} split) to Hafnia Dataset "
+        samples_in_split = []
+        for image, class_idx in track(torchvision_dataset, total=n_samples_per_split, description=description):
+            (width, height) = image.size
+            path_image = save_pil_image_with_hash_name(image, path_hafnia_conversions)
+            sample = Sample(
+                file_path=str(path_image),
+                height=height,
+                width=width,
+                split=split_name,
+                classifications=[
+                    Classification(
+                        class_name=class_index_to_name[class_idx],
+                        class_idx=class_idx,
+                    )
+                ],
+            )
+            samples_in_split.append(sample)
+
+            if n_samples_per_split is not None and len(samples_in_split) >= n_samples_per_split:
+                break
+
+        samples.extend(samples_in_split)
+    class_names = list(class_name_to_index.keys())
+    tasks = [TaskInfo(primitive=Classification, class_names=class_names)]
+
+    return samples, tasks
+
+
+def _download_and_extract_caltech_dataset(dataset_name: str, force_redownload: bool) -> Path:
+    path_torchvision_dataset = utils.get_path_torchvision_downloads() / dataset_name
+
+    if force_redownload:
+        shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+
+    if path_torchvision_dataset.exists():
+        return path_torchvision_dataset
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        path_tmp_output = Path(tmpdirname)
+        path_tmp_output.mkdir(parents=True, exist_ok=True)
+
+        if dataset_name == "caltech-101":
+            download_and_extract_archive(
+                "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
+                download_root=path_tmp_output,
+                filename="caltech-101.zip",
+                md5="3138e1922a9193bfa496528edbbc45d0",
+            )
+            path_output_extracted = path_tmp_output / "caltech-101"
+            for gzip_file in os.listdir(path_output_extracted):
+                if gzip_file.endswith(".gz"):
+                    extract_archive(os.path.join(path_output_extracted, gzip_file), path_output_extracted)
+            path_org = path_output_extracted / "101_ObjectCategories"
+
+        elif dataset_name == "caltech-256":
+            org_dataset_name = "256_ObjectCategories"
+            path_org = path_tmp_output / org_dataset_name
+            download_and_extract_archive(
+                url=f"https://data.caltech.edu/records/nyy15-4j048/files/{org_dataset_name}.tar",
+                download_root=path_tmp_output,
+                md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+                remove_finished=True,
+            )
+
+        else:
+            raise ValueError(f"Unknown dataset name: {dataset_name}. Supported: caltech-101, caltech-256")
+
+        shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+        shutil.move(path_org, path_torchvision_dataset)
+        return path_torchvision_dataset
hafnia/dataset/hafnia_dataset.py
CHANGED
@@ -8,14 +8,15 @@ from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from random import Random
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import more_itertools
 import numpy as np
 import polars as pl
+from packaging.version import Version
 from PIL import Image
 from pydantic import BaseModel, Field, field_serializer, field_validator
-from
+from rich.progress import track
 
 import hafnia
 from hafnia.dataset import dataset_helpers
@@ -29,10 +30,14 @@ from hafnia.dataset.dataset_names import (
     ColumnName,
     SplitName,
 )
-from hafnia.dataset.operations import
+from hafnia.dataset.operations import (
+    dataset_stats,
+    dataset_transformations,
+    table_transformations,
+)
 from hafnia.dataset.operations.table_transformations import (
     check_image_paths,
-
+    read_samples_from_path,
 )
 from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
@@ -44,9 +49,17 @@ from hafnia.log import user_logger
 
 
 class TaskInfo(BaseModel):
-    primitive: Type[Primitive]
-
-
+    primitive: Type[Primitive] = Field(
+        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+    )
+    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+    name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Optional name for the task. 'None' will use default name of the provided primitive. "
+            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+        ),
+    )
 
     def model_post_init(self, __context: Any) -> None:
         if self.name is None:
@@ -99,17 +112,37 @@ class TaskInfo(BaseModel):
 
 
 class DatasetInfo(BaseModel):
-    dataset_name: str
-    version: str
-    tasks: List[TaskInfo]
-    distributions: Optional[List[TaskInfo]] = None
-
-
-
+    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+    version: Optional[str] = Field(default=None, description="Version of the dataset")
+    tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+    distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
+    reference_bibtex: Optional[str] = Field(
+        default=None,
+        description="Optional, BibTeX reference to dataset publication",
+    )
+    reference_paper_url: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to dataset publication",
+    )
+    reference_dataset_page: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to the dataset page",
+    )
+    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+    format_version: str = Field(
+        default=hafnia.__dataset_format_version__,
+        description="Version of the Hafnia dataset format. You should not set this manually.",
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.now,
+        description="Timestamp of the last update to the dataset info. You should not set this manually.",
+    )
 
     @field_validator("tasks", mode="after")
     @classmethod
-    def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+        if tasks is None:
+            return []
         task_name_counts = collections.Counter(task.name for task in tasks)
         duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
         if duplicate_task_names:
@@ -118,6 +151,35 @@ class DatasetInfo(BaseModel):
         )
         return tasks
 
+    @field_validator("format_version")
+    @classmethod
+    def _validate_format_version(cls, format_version: str) -> str:
+        try:
+            Version(format_version)
+        except Exception as e:
+            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+        if Version(format_version) > Version(hafnia.__dataset_format_version__):
+            user_logger.warning(
+                f"The loaded dataset format version '{format_version}' is newer than the format version "
+                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                f"updating Hafnia package."
+            )
+        return format_version
+
+    @field_validator("version")
+    @classmethod
+    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+        if dataset_version is None:
+            return None
+
+        try:
+            Version(dataset_version)
+        except Exception as e:
+            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+        return dataset_version
+
     def check_for_duplicate_task_names(self) -> List[TaskInfo]:
         return self._validate_check_for_duplicate_tasks(self.tasks)
 
@@ -187,7 +249,7 @@ class DatasetInfo(BaseModel):
         meta.update(info1.meta or {})
         return DatasetInfo(
             dataset_name=info0.dataset_name + "+" + info1.dataset_name,
-            version=
+            version=None,
             tasks=list(unique_tasks),
             distributions=list(distributions),
             meta=meta,
@@ -258,22 +320,40 @@ class DatasetInfo(BaseModel):
 
 
 class Sample(BaseModel):
-
-    height: int
-    width: int
-    split: str
-    tags: List[str] =
-
-
-
-
-
-
-
-
-
-
-
+    file_path: str = Field(description="Path to the image file")
+    height: int = Field(description="Height of the image")
+    width: int = Field(description="Width of the image")
+    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+    tags: List[str] = Field(
+        default_factory=list,
+        description="Tags for a given sample. Used for creating subsets of the dataset.",
+    )
+    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+    sample_index: Optional[int] = Field(
+        default=None,
+        description="Don't manually set this, it is used for indexing samples in the dataset.",
+    )
+    classifications: Optional[List[Classification]] = Field(
+        default=None, description="Optional list of classifications"
+    )
+    objects: Optional[List[Bbox]] = Field(default=None, description="Optional list of objects (bounding boxes)")
+    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+    dataset_name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Don't manually set this, it will be automatically defined during initialization. "
+            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+        ),
+    )
+    meta: Optional[Dict] = Field(
+        default=None,
+        description="Additional metadata, e.g., camera settings, GPS data, etc.",
+    )
 
     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
         """
@@ -294,7 +374,7 @@ class Sample(BaseModel):
         Reads the image from the file path and returns it as a PIL Image.
         Raises FileNotFoundError if the image file does not exist.
         """
-        path_image = Path(self.
+        path_image = Path(self.file_path)
         if not path_image.exists():
             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
 
@@ -413,30 +493,23 @@ class HafniaDataset:
             yield row
 
     def __post_init__(self):
-        samples = self.samples
-        if ColumnName.SAMPLE_INDEX not in samples.columns:
-            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
-
-        # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-        if ColumnName.TAGS not in samples.columns:
-            tags_column: List[List[str]] = [[] for _ in range(len(self))]  # type: ignore[annotation-unchecked]
-            samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
-
-        self.samples = samples
+        self.samples, self.info = _dataset_corrections(self.samples, self.info)
 
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+        path_folder = Path(path_folder)
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)
 
         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-
+        samples = read_samples_from_path(path_folder)
+        samples, dataset_info = _dataset_corrections(samples, dataset_info)
 
         # Convert from relative paths to absolute paths
         dataset_root = path_folder.absolute().as_posix() + "/"
-
+        samples = samples.with_columns((dataset_root + pl.col(ColumnName.FILE_PATH)).alias(ColumnName.FILE_PATH))
         if check_for_images:
-            check_image_paths(
-        return HafniaDataset(samples=
+            check_image_paths(samples)
+        return HafniaDataset(samples=samples, info=dataset_info)
 
     @staticmethod
     def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
@@ -464,6 +537,14 @@ class HafniaDataset:
 
         table = pl.from_records(json_samples)
         table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+        # Add 'dataset_name' to samples
+        table = table.with_columns(
+            pl.when(pl.col(ColumnName.DATASET_NAME).is_null())
+            .then(pl.lit(info.dataset_name))
+            .otherwise(pl.col(ColumnName.DATASET_NAME))
+            .alias(ColumnName.DATASET_NAME)
+        )
         return HafniaDataset(info=info, samples=table)
 
     @staticmethod
@@ -518,6 +599,28 @@ class HafniaDataset:
         merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
         return merged_dataset
 
+    @staticmethod
+    def from_name_public_dataset(
+        name: str,
+        force_redownload: bool = False,
+        n_samples: Optional[int] = None,
+    ) -> HafniaDataset:
+        from hafnia.dataset.format_conversions.torchvision_datasets import (
+            torchvision_to_hafnia_converters,
+        )
+
+        name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+        if name not in name_to_torchvision_function:
+            raise ValueError(
+                f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+            )
+        vision_dataset = name_to_torchvision_function[name]
+        return vision_dataset(
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+
     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
         return dataset.update_samples(table)
@@ -607,7 +710,7 @@ class HafniaDataset:
 
     def class_mapper(
         dataset: "HafniaDataset",
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -778,13 +881,14 @@ class HafniaDataset:
         path_folder.mkdir(parents=True)
 
         new_relative_paths = []
-
+        org_paths = self.samples[ColumnName.FILE_PATH].to_list()
+        for org_path in track(org_paths, description="- Copy images"):
             new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
            )
            new_relative_paths.append(str(new_path.relative_to(path_folder)))
-        table = self.samples.with_columns(pl.Series(new_relative_paths).alias(
+        table = self.samples.with_columns(pl.Series(new_relative_paths).alias(ColumnName.FILE_PATH))
 
         if drop_null_cols:  # Drops all unused/Null columns
             table = table.drop(pl.selectors.by_dtype(pl.Null))
@@ -846,3 +950,25 @@ def get_or_create_dataset_path_from_recipe(
     dataset.write(path_dataset)
 
     return path_dataset
+
+
+def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+    format_version_of_dataset = Version(dataset_info.format_version)
+
+    ## Backwards compatibility fixes for older dataset versions
+    if format_version_of_dataset <= Version("0.3.0"):
+        if ColumnName.DATASET_NAME not in samples.columns:
+            samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))
+
+        if "file_name" in samples.columns:
+            samples = samples.rename({"file_name": ColumnName.FILE_PATH})
+
+        if ColumnName.SAMPLE_INDEX not in samples.columns:
+            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+        # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+        if ColumnName.TAGS not in samples.columns:
+            tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+            samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
+
+    return samples, dataset_info
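
As a rough sketch (not shown in the diff itself) of how the additions above fit together: the new HafniaDataset.from_name_public_dataset entry point dispatches to the torchvision converters, and DatasetInfo now validates its version strings. Signatures follow the hunks above; the example values are illustrative only.

from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, TaskInfo
from hafnia.dataset.primitives import Classification

# New in 0.4.0: build a Hafnia dataset from a supported public torchvision dataset by name.
dataset = HafniaDataset.from_name_public_dataset("cifar10", n_samples=100)

# 'version' is now optional, but when given it must parse as a valid version string
# (packaging.version.Version); a 'format_version' newer than the installed package only warns.
info = DatasetInfo(
    dataset_name="my-dataset",
    version="1.0.0",
    tasks=[TaskInfo(primitive=Classification, class_names=["cat", "dog"])],
)
# version="not-a-version" would be rejected by the new version validator.
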
hafnia/dataset/operations/dataset_stats.py
CHANGED

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Type
 import polars as pl
 import rich
 from rich import print as rprint
+from rich.progress import track
 from rich.table import Table
-from tqdm import tqdm
 
 from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
 from hafnia.dataset.operations.table_transformations import create_primitive_table
@@ -179,7 +179,6 @@ def check_dataset(dataset: HafniaDataset):
     from hafnia.dataset.hafnia_dataset import Sample
 
     user_logger.info("Checking Hafnia dataset...")
-    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
 
     sample_dataset = dataset.create_sample_dataset()
@@ -215,7 +214,7 @@ def check_dataset(dataset: HafniaDataset):
         f"classes: {class_names}. "
     )
 
-    for sample_dict in
+    for sample_dict in track(dataset, description="Checking samples in dataset"):
         sample = Sample(**sample_dict)  # noqa: F841
 
 