hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- cli/__main__.py +3 -1
- cli/config.py +43 -3
- cli/keychain.py +88 -0
- cli/profile_cmds.py +5 -2
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_helpers.py +9 -2
- hafnia/dataset/dataset_names.py +130 -16
- hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
- hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
- hafnia/dataset/dataset_upload_helper.py +83 -22
- hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
- hafnia/dataset/hafnia_dataset.py +396 -96
- hafnia/dataset/operations/dataset_stats.py +84 -73
- hafnia/dataset/operations/dataset_transformations.py +116 -47
- hafnia/dataset/operations/table_transformations.py +135 -17
- hafnia/dataset/primitives/bbox.py +25 -14
- hafnia/dataset/primitives/bitmask.py +22 -15
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +15 -10
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +12 -9
- hafnia/experiment/hafnia_logger.py +0 -9
- hafnia/platform/dataset_recipe.py +7 -2
- hafnia/platform/datasets.py +5 -9
- hafnia/platform/download.py +24 -90
- hafnia/torch_helpers.py +12 -12
- hafnia/utils.py +17 -0
- hafnia/visualizations/image_visualizations.py +3 -1
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
- hafnia-0.4.1.dist-info/RECORD +57 -0
- hafnia-0.3.0.dist-info/RECORD +0 -53
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
CHANGED
@@ -8,14 +8,16 @@ from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from random import Random
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

+import cv2
 import more_itertools
 import numpy as np
 import polars as pl
+from packaging.version import Version
 from PIL import Image
 from pydantic import BaseModel, Field, field_serializer, field_validator
-from …
+from rich.progress import track

 import hafnia
 from hafnia.dataset import dataset_helpers
@@ -26,13 +28,20 @@ from hafnia.dataset.dataset_names import (
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
     TAG_IS_SAMPLE,
-    …
+    AwsCredentials,
+    PrimitiveField,
+    SampleField,
     SplitName,
+    StorageFormat,
 )
-from hafnia.dataset.…
-    …
-    …
-    …
+from hafnia.dataset.format_conversions import (
+    format_image_classification_folder,
+    format_yolo,
+)
+from hafnia.dataset.operations import (
+    dataset_stats,
+    dataset_transformations,
+    table_transformations,
 )
 from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
@@ -44,14 +53,30 @@ from hafnia.log import user_logger


 class TaskInfo(BaseModel):
-    primitive: Type[Primitive]
-    …
-    …
+    primitive: Type[Primitive] = Field(
+        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+    )
+    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+    name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Optional name for the task. 'None' will use default name of the provided primitive. "
+            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+        ),
+    )

     def model_post_init(self, __context: Any) -> None:
         if self.name is None:
             self.name = self.primitive.default_task_name()

+    def get_class_index(self, class_name: str) -> int:
+        """Get class index for a given class name"""
+        if self.class_names is None:
+            raise ValueError(f"Task '{self.name}' has no class names defined.")
+        if class_name not in self.class_names:
+            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
+        return self.class_names.index(class_name)
+
     # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
     # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
     # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
@@ -87,6 +112,10 @@ class TaskInfo(BaseModel):
         )
         return class_names

+    def full_name(self) -> str:
+        """Get qualified name for the task: <primitive_name>:<task_name>"""
+        return f"{self.primitive.__name__}:{self.name}"
+
     # To get unique hash value for TaskInfo objects
     def __hash__(self) -> int:
         class_names = self.class_names or []
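Usage sketch (not part of the package diff): with the new class_names field and the get_class_index / full_name helpers above, a TaskInfo can resolve class indices and report a qualified name. The class names below are made up; the expected outputs assume Bbox's default task name is 'bboxes', as the field description suggests.

    from hafnia.dataset.hafnia_dataset import TaskInfo
    from hafnia.dataset.primitives.bbox import Bbox

    task = TaskInfo(primitive=Bbox, class_names=["person", "vehicle"])

    print(task.name)                        # "bboxes" (default task name of the Bbox primitive)
    print(task.full_name())                 # "Bbox:bboxes"
    print(task.get_class_index("vehicle"))  # 1; unknown class names raise ValueError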
@@ -99,17 +128,36 @@ class TaskInfo(BaseModel):


 class DatasetInfo(BaseModel):
-    dataset_name: str
-    version: str
-    tasks: List[TaskInfo]
-    …
-    …
-    …
-    …
+    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+    version: Optional[str] = Field(default=None, description="Version of the dataset")
+    tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+    reference_bibtex: Optional[str] = Field(
+        default=None,
+        description="Optional, BibTeX reference to dataset publication",
+    )
+    reference_paper_url: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to dataset publication",
+    )
+    reference_dataset_page: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to the dataset page",
+    )
+    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+    format_version: str = Field(
+        default=hafnia.__dataset_format_version__,
+        description="Version of the Hafnia dataset format. You should not set this manually.",
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.now,
+        description="Timestamp of the last update to the dataset info. You should not set this manually.",
+    )

     @field_validator("tasks", mode="after")
     @classmethod
-    def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+        if tasks is None:
+            return []
         task_name_counts = collections.Counter(task.name for task in tasks)
         duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
         if duplicate_task_names:
@@ -118,6 +166,35 @@ class DatasetInfo(BaseModel):
             )
         return tasks

+    @field_validator("format_version")
+    @classmethod
+    def _validate_format_version(cls, format_version: str) -> str:
+        try:
+            Version(format_version)
+        except Exception as e:
+            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+        if Version(format_version) > Version(hafnia.__dataset_format_version__):
+            user_logger.warning(
+                f"The loaded dataset format version '{format_version}' is newer than the format version "
+                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                f"updating Hafnia package."
+            )
+        return format_version
+
+    @field_validator("version")
+    @classmethod
+    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+        if dataset_version is None:
+            return None
+
+        try:
+            Version(dataset_version)
+        except Exception as e:
+            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+        return dataset_version
+
     def check_for_duplicate_task_names(self) -> List[TaskInfo]:
         return self._validate_check_for_duplicate_tasks(self.tasks)

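Usage sketch (not part of the package diff): DatasetInfo.version is now optional, and both version and format_version are validated with packaging.version.Version. The dataset name below is a placeholder.

    from hafnia.dataset.hafnia_dataset import DatasetInfo

    info = DatasetInfo(dataset_name="my-dataset", version="1.0.0")  # accepted

    # An unparsable version string is rejected by the '_validate_version' validator;
    # pydantic surfaces it as a validation error (a ValueError subclass).
    try:
        DatasetInfo(dataset_name="my-dataset", version="not a version")
    except ValueError as exc:
        print(exc)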
@@ -182,14 +259,12 @@ class DatasetInfo(BaseModel):
                 f"Hafnia format version '{hafnia.__dataset_format_version__}'."
             )
         unique_tasks = set(info0.tasks + info1.tasks)
-        distributions = set((info0.distributions or []) + (info1.distributions or []))
         meta = (info0.meta or {}).copy()
         meta.update(info1.meta or {})
         return DatasetInfo(
             dataset_name=info0.dataset_name + "+" + info1.dataset_name,
-            version=…
+            version=None,
             tasks=list(unique_tasks),
-            distributions=list(distributions),
             meta=meta,
             format_version=dataset_format_version,
         )
@@ -205,16 +280,24 @@ class DatasetInfo(BaseModel):
             raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
         return tasks_with_name[0]

-    def …
+    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
         """
-        Get …
-        have the same primitive type.
+        Get all tasks by their primitive type.
         """
         if isinstance(primitive, str):
             primitive = get_primitive_type_from_string(primitive)

         tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
-        …
+        return tasks_with_primitive
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+
+        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
+        if len(tasks_with_primitive) == 0:
             raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
         if len(tasks_with_primitive) > 1:
             raise ValueError(
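Usage sketch (not part of the package diff): the new get_tasks_by_primitive / get_task_by_primitive split lets callers either list all tasks of a primitive type or demand exactly one. The dataset name and class list are placeholders.

    from hafnia.dataset.hafnia_dataset import DatasetInfo, TaskInfo
    from hafnia.dataset.primitives.bbox import Bbox

    info = DatasetInfo(
        dataset_name="demo",
        tasks=[TaskInfo(primitive=Bbox, class_names=["person"])],
    )

    bbox_tasks = info.get_tasks_by_primitive(Bbox)  # list of all Bbox tasks (may be empty)
    bbox_task = info.get_task_by_primitive(Bbox)    # exactly one expected; raises ValueError otherwise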
@@ -258,22 +341,44 @@ class DatasetInfo(BaseModel):


 class Sample(BaseModel):
-    …
-    height: int
-    width: int
-    split: str
-    tags: List[str] = …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    file_path: Optional[str] = Field(description="Path to the image/video file.")
+    height: int = Field(description="Height of the image")
+    width: int = Field(description="Width of the image")
+    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+    tags: List[str] = Field(
+        default_factory=list,
+        description="Tags for a given sample. Used for creating subsets of the dataset.",
+    )
+    storage_format: str = Field(
+        default=StorageFormat.IMAGE,
+        description="Storage format. Sample data is stored as image or inside a video or zip file.",
+    )
+    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+    sample_index: Optional[int] = Field(
+        default=None,
+        description="Don't manually set this, it is used for indexing samples in the dataset.",
+    )
+    classifications: Optional[List[Classification]] = Field(
+        default=None, description="Optional list of classifications"
+    )
+    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
+    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+    dataset_name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Don't manually set this, it will be automatically defined during initialization. "
+            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+        ),
+    )
+    meta: Optional[Dict] = Field(
+        default=None,
+        description="Additional metadata, e.g., camera settings, GPS data, etc.",
+    )

     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
         """
@@ -294,7 +399,9 @@ class Sample(BaseModel):
         Reads the image from the file path and returns it as a PIL Image.
         Raises FileNotFoundError if the image file does not exist.
         """
-        …
+        if self.file_path is None:
+            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
+        path_image = Path(self.file_path)
         if not path_image.exists():
             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

@@ -302,8 +409,22 @@ class Sample(BaseModel):
         return image

     def read_image(self) -> np.ndarray:
-        …
-        …
+        if self.storage_format == StorageFormat.VIDEO:
+            video = cv2.VideoCapture(str(self.file_path))
+            if self.collection_index is None:
+                raise ValueError("collection_index must be set for video storage format to read the correct frame.")
+            video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
+            success, image = video.read()
+            video.release()
+            if not success:
+                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
+            return image
+
+        elif self.storage_format == StorageFormat.IMAGE:
+            image_pil = self.read_image_pillow()
+            image = np.array(image_pil)
+        else:
+            raise ValueError(f"Unsupported storage format: {self.storage_format}")
         return image

     def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
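Usage sketch (not part of the package diff): with StorageFormat.VIDEO, a sample can point at a video file plus a frame index, and read_image() decodes that single frame with OpenCV. The file path, sizes, and frame number below are placeholders.

    from hafnia.dataset.dataset_names import StorageFormat
    from hafnia.dataset.hafnia_dataset import Sample

    sample = Sample(
        file_path="videos/clip_0001.mp4",
        height=720,
        width=1280,
        split="train",
        storage_format=StorageFormat.VIDEO,
        collection_id="clip_0001",
        collection_index=42,  # frame number to decode
    )

    frame = sample.read_image()  # np.ndarray decoded via cv2.VideoCapture (BGR channel order)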
@@ -386,9 +507,11 @@ class HafniaDataset:
     samples: pl.DataFrame

     # Function mapping: Dataset stats
-    …
-    …
-    …
+    calculate_split_counts = dataset_stats.calculate_split_counts
+    calculate_split_counts_extended = dataset_stats.calculate_split_counts_extended
+    calculate_task_class_counts = dataset_stats.calculate_task_class_counts
+    calculate_class_counts = dataset_stats.calculate_class_counts
+    calculate_primitive_counts = dataset_stats.calculate_primitive_counts

     # Function mapping: Print stats
     print_stats = dataset_stats.print_stats
@@ -401,6 +524,13 @@ class HafniaDataset:

     # Function mapping: Dataset transformations
     transform_images = dataset_transformations.transform_images
+    convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
+
+    # Import / export functions
+    from_yolo_format = format_yolo.from_yolo_format
+    to_yolo_format = format_yolo.to_yolo_format
+    to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
+    from_image_classification_folder = format_image_classification_folder.from_image_classification_folder

     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
@@ -413,30 +543,23 @@ class HafniaDataset:
             yield row

     def __post_init__(self):
-        samples = self.samples
-        if ColumnName.SAMPLE_INDEX not in samples.columns:
-            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
-
-        # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-        if ColumnName.TAGS not in samples.columns:
-            tags_column: List[List[str]] = [[] for _ in range(len(self))]  # type: ignore[annotation-unchecked]
-            samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
-
-        self.samples = samples
+        self.samples, self.info = _dataset_corrections(self.samples, self.info)

     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+        path_folder = Path(path_folder)
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)

         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-        …
+        samples = table_transformations.read_samples_from_path(path_folder)
+        samples, dataset_info = _dataset_corrections(samples, dataset_info)

         # Convert from relative paths to absolute paths
         dataset_root = path_folder.absolute().as_posix() + "/"
-        …
+        samples = samples.with_columns((dataset_root + pl.col(SampleField.FILE_PATH)).alias(SampleField.FILE_PATH))
         if check_for_images:
-            check_image_paths(…
-        return HafniaDataset(samples=…
+            table_transformations.check_image_paths(samples)
+        return HafniaDataset(samples=samples, info=dataset_info)

     @staticmethod
     def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
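Usage sketch (not part of the package diff): from_path() now coerces its argument with Path(...) and runs the shared _dataset_corrections() fixes before resolving relative file paths, so a plain string path works too. The folder below is a placeholder.

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_path("/data/my-hafnia-dataset", check_for_images=False)
    print(len(dataset))  # number of samples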
@@ -462,8 +585,12 @@ class HafniaDataset:
         else:
             raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

-        …
-        …
+        # To ensure that the 'file_path' column is of type string even if all samples have 'None' as file_path
+        schema_override = {SampleField.FILE_PATH: pl.String}
+        table = pl.from_records(json_samples, schema_overrides=schema_override)
+        table = table.drop(pl.selectors.by_dtype(pl.Null))
+        table = table_transformations.add_sample_index(table)
+        table = table_transformations.add_dataset_name_if_missing(table, dataset_name=info.dataset_name)
         return HafniaDataset(info=info, samples=table)

     @staticmethod
@@ -518,6 +645,28 @@ class HafniaDataset:
             merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
         return merged_dataset

+    @staticmethod
+    def from_name_public_dataset(
+        name: str,
+        force_redownload: bool = False,
+        n_samples: Optional[int] = None,
+    ) -> HafniaDataset:
+        from hafnia.dataset.format_conversions.torchvision_datasets import (
+            torchvision_to_hafnia_converters,
+        )
+
+        name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+        if name not in name_to_torchvision_function:
+            raise ValueError(
+                f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+            )
+        vision_dataset = name_to_torchvision_function[name]
+        return vision_dataset(
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+
     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
         return dataset.update_samples(table)
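Usage sketch (not part of the package diff): from_name_public_dataset() looks the name up in torchvision_to_hafnia_converters() and raises a ValueError listing the supported names otherwise. "mnist" is a guess at a supported key, not confirmed by this diff.

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_name_public_dataset("mnist", n_samples=100)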
@@ -575,12 +724,12 @@ class HafniaDataset:
         """
         dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
         if len(dataset_split_to_be_divided) == 0:
-            split_counts = dict(dataset.samples.select(pl.col(…
+            split_counts = dict(dataset.samples.select(pl.col(SampleField.SPLIT).value_counts()).iter_rows())
             raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
         assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{split_name}' split!"
         dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=42)

-        remaining_data = dataset.samples.filter(pl.col(…
+        remaining_data = dataset.samples.filter(pl.col(SampleField.SPLIT).is_in([split_name]).not_())
         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
         dataset_new = dataset.update_samples(new_table)
         return dataset_new
@@ -593,21 +742,23 @@ class HafniaDataset:

         # Remove any pre-existing "sample"-tags
         samples = samples.with_columns(
-            pl.col(…
+            pl.col(SampleField.TAGS)
+            .list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE))
+            .alias(SampleField.TAGS)
         )

         # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
         samples = samples.with_columns(
             pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
-            .then(pl.col(…
-            .otherwise(pl.col(…
+            .then(pl.col(SampleField.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(SampleField.TAGS))
         )
         return dataset.update_samples(samples)

     def class_mapper(
         dataset: "HafniaDataset",
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -659,6 +810,47 @@ class HafniaDataset:
             dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
         )

+    def drop_task(
+        dataset: "HafniaDataset",
+        task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Drop a task from the dataset.
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        drop_task = dataset.info.get_task_by_name(task_name=task_name)
+        tasks_with_same_primitive = dataset.info.get_tasks_by_primitive(drop_task.primitive)
+
+        no_other_tasks_with_same_primitive = len(tasks_with_same_primitive) == 1
+        if no_other_tasks_with_same_primitive:
+            return dataset.drop_primitive(primitive=drop_task.primitive)
+
+        dataset.info = dataset.info.replace_task(old_task=drop_task, new_task=None)
+        dataset.samples = dataset.samples.with_columns(
+            pl.col(drop_task.primitive.column_name())
+            .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) != drop_task.name)
+            .alias(drop_task.primitive.column_name())
+        )
+
+        return dataset
+
+    def drop_primitive(
+        dataset: "HafniaDataset",
+        primitive: Type[Primitive],
+    ) -> "HafniaDataset":
+        """
+        Drop a primitive from the dataset.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        tasks_to_drop = dataset.info.get_tasks_by_primitive(primitive=primitive)
+        for task in tasks_to_drop:
+            dataset.info = dataset.info.replace_task(old_task=task, new_task=None)
+
+        # Drop the primitive column from the samples table
+        dataset.samples = dataset.samples.drop(primitive.column_name())
+        return dataset
+
     def select_samples_by_class_name(
         dataset: HafniaDataset,
         name: Union[List[str], str],
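Usage sketch (not part of the package diff): drop_task() removes a single annotation task by name and falls back to drop_primitive() when it is the only task of its primitive; drop_primitive() removes every task of that primitive together with its samples column. It assumes an existing HafniaDataset `dataset` whose info contains a Bbox task named "bboxes".

    from hafnia.dataset.primitives.bbox import Bbox

    dataset_without_task = dataset.drop_task(task_name="bboxes")
    dataset_without_bboxes = dataset.drop_primitive(primitive=Bbox)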
@@ -695,13 +887,63 @@ class HafniaDataset:

         return HafniaDataset(info=merged_info, samples=merged_samples)

-    def …
+    def download_files_aws(
+        dataset: HafniaDataset,
+        path_output_folder: Path,
+        aws_credentials: AwsCredentials,
+        force_redownload: bool = False,
+    ) -> HafniaDataset:
+        from hafnia.platform.datasets import fast_copy_files_s3
+
+        remote_src_paths = dataset.samples[SampleField.REMOTE_PATH].unique().to_list()
+        update_rows = []
+        local_dst_paths = []
+        for remote_src_path in remote_src_paths:
+            local_path_str = (path_output_folder / "data" / Path(remote_src_path).name).absolute().as_posix()
+            local_dst_paths.append(local_path_str)
+            update_rows.append(
+                {
+                    SampleField.REMOTE_PATH: remote_src_path,
+                    SampleField.FILE_PATH: local_path_str,
+                }
+            )
+        update_df = pl.DataFrame(update_rows)
+        samples = dataset.samples.update(update_df, on=[SampleField.REMOTE_PATH])
+        dataset = dataset.update_samples(samples)
+
+        if not force_redownload:
+            download_indices = [idx for idx, local_path in enumerate(local_dst_paths) if not Path(local_path).exists()]
+            n_files = len(local_dst_paths)
+            skip_files = n_files - len(download_indices)
+            if skip_files > 0:
+                user_logger.info(
+                    f"Found {skip_files}/{n_files} files already exists. Downloading {len(download_indices)} files."
+                )
+            remote_src_paths = [remote_src_paths[idx] for idx in download_indices]
+            local_dst_paths = [local_dst_paths[idx] for idx in download_indices]
+
+        if len(remote_src_paths) == 0:
+            user_logger.info(
+                "All files already exist locally. Skipping download. Set 'force_redownload=True' to re-download."
+            )
+            return dataset
+
+        environment_vars = aws_credentials.aws_credentials()
+        fast_copy_files_s3(
+            src_paths=remote_src_paths,
+            dst_paths=local_dst_paths,
+            append_envs=environment_vars,
+            description="Downloading images",
+        )
+        return dataset
+
+    def to_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
         """
         Splits the dataset into multiple datasets based on the 'split' column.
         Returns a dictionary with split names as keys and HafniaDataset objects as values.
         """
-        if …
-            raise ValueError(f"Dataset must contain a '{…
+        if SampleField.SPLIT not in self.samples.columns:
+            raise ValueError(f"Dataset must contain a '{SampleField.SPLIT}' column.")

         splits = {}
         for split_name in SplitName.valid_splits():
@@ -710,20 +952,11 @@ class HafniaDataset:
         return splits

     def create_sample_dataset(self) -> "HafniaDataset":
-        …
-        …
-            user_logger.warning(
-                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
-                "Please use the 'tags' column with the tag 'sample' instead."
-            )
-            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
-            return self.update_samples(table)
-
-        if ColumnName.TAGS not in self.samples.columns:
-            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+        if SampleField.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{SampleField.TAGS}' column.")

         table = self.samples.filter(
-            pl.col(…
+            pl.col(SampleField.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
         )
         return self.update_samples(table)

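Usage sketch (not part of the package diff): sample selection is now driven purely by the "sample" tag (the deprecated 'is_sample' column handling above was removed), and split selection goes through SampleField.SPLIT. It assumes an existing HafniaDataset `dataset`.

    sample_dataset = dataset.create_sample_dataset()                  # samples tagged with TAG_IS_SAMPLE
    train_dataset = dataset.create_split_dataset(split_name="train")  # samples from the 'train' split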
@@ -734,10 +967,10 @@ class HafniaDataset:
             split_names = split_name

         for name in split_names:
-            if name not in SplitName.…
+            if name not in SplitName.all_split_names():
                 raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

-        filtered_dataset = self.samples.filter(pl.col(…
+        filtered_dataset = self.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
         return self.update_samples(filtered_dataset)

     def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
@@ -772,29 +1005,69 @@ class HafniaDataset:
     def copy(self) -> "HafniaDataset":
         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())

+    def create_primitive_table(
+        self,
+        primitive: Type[Primitive],
+        task_name: Optional[str] = None,
+        keep_sample_data: bool = False,
+    ) -> pl.DataFrame:
+        return table_transformations.create_primitive_table(
+            samples_table=self.samples,
+            PrimitiveType=primitive,
+            task_name=task_name,
+            keep_sample_data=keep_sample_data,
+        )
+
     def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
+        path_folder = path_folder.absolute()
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
-        …
-        …
-        …
+        hafnia_dataset = self.copy()  # To avoid inplace modifications
+        new_paths = []
+        org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
+        for org_path in track(org_paths, description="- Copy images"):
             new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
             )
-            …
-        …
+            new_paths.append(str(new_path))
+        hafnia_dataset.samples = hafnia_dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
+        hafnia_dataset.write_annotations(
+            path_folder=path_folder,
+            drop_null_cols=drop_null_cols,
+            add_version=add_version,
+        )

+    def write_annotations(
+        dataset: HafniaDataset,
+        path_folder: Path,
+        drop_null_cols: bool = True,
+        add_version: bool = False,
+    ) -> None:
+        """
+        Writes only the annotations files (JSONL and Parquet) to the specified folder.
+        """
+        user_logger.info(f"Writing dataset annotations to {path_folder}...")
+        path_folder = path_folder.absolute()
+        if not path_folder.exists():
+            path_folder.mkdir(parents=True)
+        dataset.info.write_json(path_folder / FILENAME_DATASET_INFO)
+
+        samples = dataset.samples
         if drop_null_cols:  # Drops all unused/Null columns
-            …
+            samples = samples.drop(pl.selectors.by_dtype(pl.Null))
+
+        # Store only relative paths in the annotations files
+        absolute_paths = samples[SampleField.FILE_PATH].to_list()
+        relative_paths = [str(Path(path).relative_to(path_folder)) for path in absolute_paths]
+        samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))

-        …
-        …
-        self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+        samples.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
+        samples.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed

         if add_version:
-            path_version = path_folder / "versions" / f"{…
+            path_version = path_folder / "versions" / f"{dataset.info.version}"
             path_version.mkdir(parents=True, exist_ok=True)
             for filename in DATASET_FILENAMES_REQUIRED:
                 shutil.copy2(path_folder / filename, path_version / filename)
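Usage sketch (not part of the package diff): write() now copies image files (renamed to their hash values) and then delegates to the new write_annotations(), which writes dataset_info plus the JSONL/Parquet annotation files with relative paths. The output folder is a placeholder and `dataset` is assumed to be an existing HafniaDataset.

    from pathlib import Path

    dataset.write(Path("./exported_dataset"), add_version=False)        # images + annotations
    dataset.write_annotations(path_folder=Path("./exported_dataset"))   # annotations only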
@@ -846,3 +1119,30 @@ def get_or_create_dataset_path_from_recipe(
         dataset.write(path_dataset)

     return path_dataset
+
+
+def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+    format_version_of_dataset = Version(dataset_info.format_version)
+
+    ## Backwards compatibility fixes for older dataset versions
+    if format_version_of_dataset < Version("0.2.0"):
+        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
+
+    if "file_name" in samples.columns:
+        samples = samples.rename({"file_name": SampleField.FILE_PATH})
+
+    if SampleField.SAMPLE_INDEX not in samples.columns:
+        samples = table_transformations.add_sample_index(samples)
+
+    # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+    if SampleField.TAGS not in samples.columns:
+        tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+    if SampleField.STORAGE_FORMAT not in samples.columns:
+        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
+
+    return samples, dataset_info