hafnia 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_names.py +128 -15
- hafnia/dataset/dataset_upload_helper.py +30 -25
- hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
- hafnia/dataset/hafnia_dataset.py +246 -72
- hafnia/dataset/operations/dataset_stats.py +82 -70
- hafnia/dataset/operations/dataset_transformations.py +102 -37
- hafnia/dataset/operations/table_transformations.py +132 -15
- hafnia/dataset/primitives/bbox.py +3 -5
- hafnia/dataset/primitives/bitmask.py +2 -7
- hafnia/dataset/primitives/classification.py +3 -3
- hafnia/dataset/primitives/polygon.py +2 -4
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +2 -2
- hafnia/platform/datasets.py +3 -7
- hafnia/platform/download.py +1 -72
- hafnia/torch_helpers.py +12 -12
- hafnia/visualizations/image_visualizations.py +2 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +4 -4
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/RECORD +25 -24
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
CHANGED
```diff
@@ -10,6 +10,7 @@ from pathlib import Path
 from random import Random
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

+import cv2
 import more_itertools
 import numpy as np
 import polars as pl
```
```diff
@@ -27,18 +28,21 @@ from hafnia.dataset.dataset_names import (
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
     TAG_IS_SAMPLE,
-
+    AwsCredentials,
+    PrimitiveField,
+    SampleField,
     SplitName,
+    StorageFormat,
+)
+from hafnia.dataset.format_conversions import (
+    format_image_classification_folder,
+    format_yolo,
 )
 from hafnia.dataset.operations import (
     dataset_stats,
     dataset_transformations,
     table_transformations,
 )
-from hafnia.dataset.operations.table_transformations import (
-    check_image_paths,
-    read_samples_from_path,
-)
 from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
 from hafnia.dataset.primitives.bitmask import Bitmask
```
```diff
@@ -65,6 +69,14 @@ class TaskInfo(BaseModel):
         if self.name is None:
             self.name = self.primitive.default_task_name()

+    def get_class_index(self, class_name: str) -> int:
+        """Get class index for a given class name"""
+        if self.class_names is None:
+            raise ValueError(f"Task '{self.name}' has no class names defined.")
+        if class_name not in self.class_names:
+            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
+        return self.class_names.index(class_name)
+
     # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
     # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
     # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
```
```diff
@@ -100,6 +112,10 @@ class TaskInfo(BaseModel):
         )
         return class_names

+    def full_name(self) -> str:
+        """Get qualified name for the task: <primitive_name>:<task_name>"""
+        return f"{self.primitive.__name__}:{self.name}"
+
     # To get unique hash value for TaskInfo objects
     def __hash__(self) -> int:
         class_names = self.class_names or []
```
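The two `TaskInfo` helpers added above are small lookups over the task's `class_names` and `primitive` fields. A minimal usage sketch, assuming a `TaskInfo` can be constructed from just `primitive` and `class_names`; the class list here is hypothetical:

```python
# Sketch only: the class list is hypothetical and TaskInfo may have further required fields.
from hafnia.dataset.hafnia_dataset import TaskInfo
from hafnia.dataset.primitives.bbox import Bbox

task = TaskInfo(primitive=Bbox, class_names=["person", "car"])

print(task.full_name())             # "<primitive_name>:<task_name>", e.g. "Bbox:<default task name>"
print(task.get_class_index("car"))  # -> 1
# task.get_class_index("truck")     # would raise ValueError: not in class_names
```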
```diff
@@ -115,7 +131,6 @@ class DatasetInfo(BaseModel):
     dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
     version: Optional[str] = Field(default=None, description="Version of the dataset")
     tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
-    distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
     reference_bibtex: Optional[str] = Field(
         default=None,
         description="Optional, BibTeX reference to dataset publication",
```
```diff
@@ -244,14 +259,12 @@ class DatasetInfo(BaseModel):
                 f"Hafnia format version '{hafnia.__dataset_format_version__}'."
             )
         unique_tasks = set(info0.tasks + info1.tasks)
-        distributions = set((info0.distributions or []) + (info1.distributions or []))
         meta = (info0.meta or {}).copy()
         meta.update(info1.meta or {})
         return DatasetInfo(
             dataset_name=info0.dataset_name + "+" + info1.dataset_name,
             version=None,
             tasks=list(unique_tasks),
-            distributions=list(distributions),
             meta=meta,
             format_version=dataset_format_version,
         )
```
```diff
@@ -267,16 +280,24 @@ class DatasetInfo(BaseModel):
             raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
         return tasks_with_name[0]

-    def
+    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
         """
-        Get
-        have the same primitive type.
+        Get all tasks by their primitive type.
         """
         if isinstance(primitive, str):
             primitive = get_primitive_type_from_string(primitive)

         tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
-
+        return tasks_with_primitive
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+
+        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
+        if len(tasks_with_primitive) == 0:
             raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
         if len(tasks_with_primitive) > 1:
             raise ValueError(
```
```diff
@@ -320,7 +341,7 @@ class DatasetInfo(BaseModel):


 class Sample(BaseModel):
-    file_path: str = Field(description="Path to the image file")
+    file_path: Optional[str] = Field(description="Path to the image/video file.")
     height: int = Field(description="Height of the image")
     width: int = Field(description="Width of the image")
     split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
```
```diff
@@ -328,6 +349,10 @@ class Sample(BaseModel):
         default_factory=list,
         description="Tags for a given sample. Used for creating subsets of the dataset.",
     )
+    storage_format: str = Field(
+        default=StorageFormat.IMAGE,
+        description="Storage format. Sample data is stored as image or inside a video or zip file.",
+    )
     collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
     collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
     remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
```
```diff
@@ -338,7 +363,7 @@ class Sample(BaseModel):
     classifications: Optional[List[Classification]] = Field(
         default=None, description="Optional list of classifications"
     )
-
+    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
     bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
     polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")

```
```diff
@@ -374,6 +399,8 @@ class Sample(BaseModel):
         Reads the image from the file path and returns it as a PIL Image.
         Raises FileNotFoundError if the image file does not exist.
         """
+        if self.file_path is None:
+            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
         path_image = Path(self.file_path)
         if not path_image.exists():
             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
```
```diff
@@ -382,8 +409,22 @@ class Sample(BaseModel):
         return image

     def read_image(self) -> np.ndarray:
-
-
+        if self.storage_format == StorageFormat.VIDEO:
+            video = cv2.VideoCapture(str(self.file_path))
+            if self.collection_index is None:
+                raise ValueError("collection_index must be set for video storage format to read the correct frame.")
+            video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
+            success, image = video.read()
+            video.release()
+            if not success:
+                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
+            return image
+
+        elif self.storage_format == StorageFormat.IMAGE:
+            image_pil = self.read_image_pillow()
+            image = np.array(image_pil)
+        else:
+            raise ValueError(f"Unsupported storage format: {self.storage_format}")
         return image

     def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
```
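`Sample.read_image` now dispatches on `storage_format`: image-backed samples still go through `read_image_pillow`, while video-backed samples seek a single frame with OpenCV. The same OpenCV pattern in isolation, as a sketch with a hypothetical file path and frame index:

```python
import cv2

# Hypothetical inputs; in the dataset these come from Sample.file_path and Sample.collection_index.
path_video = "clip.mp4"
frame_index = 42

video = cv2.VideoCapture(path_video)
video.set(cv2.CAP_PROP_POS_FRAMES, frame_index)  # seek directly to the requested frame
success, image = video.read()                    # 'image' is a numpy array in BGR channel order
video.release()
if not success:
    raise ValueError(f"Could not read frame {frame_index} from {path_video}.")
```

Note that OpenCV returns frames in BGR channel order, whereas the PIL-based image path yields RGB arrays.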
```diff
@@ -466,9 +507,11 @@ class HafniaDataset:
     samples: pl.DataFrame

     # Function mapping: Dataset stats
-
-
-
+    calculate_split_counts = dataset_stats.calculate_split_counts
+    calculate_split_counts_extended = dataset_stats.calculate_split_counts_extended
+    calculate_task_class_counts = dataset_stats.calculate_task_class_counts
+    calculate_class_counts = dataset_stats.calculate_class_counts
+    calculate_primitive_counts = dataset_stats.calculate_primitive_counts

     # Function mapping: Print stats
     print_stats = dataset_stats.print_stats
```
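The stats helpers are exposed by assigning plain functions from `dataset_stats` as class attributes; because functions are descriptors, they behave as ordinary methods on a `HafniaDataset` instance. A self-contained toy illustration of that pattern (the `Toy` class and its fields are made up, not part of the hafnia API):

```python
# Toy illustration of the "function mapping" pattern used by HafniaDataset.
from typing import Dict, List


def _calculate_split_counts(dataset: "Toy") -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for split in dataset.splits:
        counts[split] = counts.get(split, 0) + 1
    return counts


class Toy:
    # Assigning the module-level function makes it callable as a method.
    calculate_split_counts = _calculate_split_counts

    def __init__(self, splits: List[str]) -> None:
        self.splits = splits


print(Toy(["train", "train", "val"]).calculate_split_counts())  # {'train': 2, 'val': 1}
```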
```diff
@@ -481,6 +524,13 @@ class HafniaDataset:

     # Function mapping: Dataset transformations
     transform_images = dataset_transformations.transform_images
+    convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
+
+    # Import / export functions
+    from_yolo_format = format_yolo.from_yolo_format
+    to_yolo_format = format_yolo.to_yolo_format
+    to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
+    from_image_classification_folder = format_image_classification_folder.from_image_classification_folder

     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
```
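The YOLO and image-classification-folder converters are attached the same way. Their signatures live in the new `format_yolo.py` and the renamed `format_image_classification_folder.py`, which are not shown in this diff, so the calls below are only a sketch: the positional path argument and the folder names are assumptions.

```python
# Sketch only: exact parameter names of the converters are not visible in this diff;
# assumes `dataset` is an existing HafniaDataset.
from pathlib import Path

dataset.to_yolo_format(Path("exports/my_dataset_yolo"))              # hypothetical export call
dataset.to_image_classification_folder(Path("exports/my_dataset"))   # hypothetical export call
```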
```diff
@@ -501,14 +551,14 @@ class HafniaDataset:
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)

         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-        samples = read_samples_from_path(path_folder)
+        samples = table_transformations.read_samples_from_path(path_folder)
         samples, dataset_info = _dataset_corrections(samples, dataset_info)

         # Convert from relative paths to absolute paths
         dataset_root = path_folder.absolute().as_posix() + "/"
-        samples = samples.with_columns((dataset_root + pl.col(
+        samples = samples.with_columns((dataset_root + pl.col(SampleField.FILE_PATH)).alias(SampleField.FILE_PATH))
         if check_for_images:
-            check_image_paths(samples)
+            table_transformations.check_image_paths(samples)
         return HafniaDataset(samples=samples, info=dataset_info)

     @staticmethod
```
```diff
@@ -535,16 +585,12 @@ class HafniaDataset:
             else:
                 raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

-
-
-
-
-        table =
-
-            .then(pl.lit(info.dataset_name))
-            .otherwise(pl.col(ColumnName.DATASET_NAME))
-            .alias(ColumnName.DATASET_NAME)
-        )
+        # To ensure that the 'file_path' column is of type string even if all samples have 'None' as file_path
+        schema_override = {SampleField.FILE_PATH: pl.String}
+        table = pl.from_records(json_samples, schema_overrides=schema_override)
+        table = table.drop(pl.selectors.by_dtype(pl.Null))
+        table = table_transformations.add_sample_index(table)
+        table = table_transformations.add_dataset_name_if_missing(table, dataset_name=info.dataset_name)
         return HafniaDataset(info=info, samples=table)

     @staticmethod
```
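The rewritten table construction leans on two polars behaviours: `schema_overrides` keeps `file_path` a string column even when every record has `file_path=None`, and a dtype selector then drops columns that are entirely null. A standalone sketch of both (the records are made up; `pl.from_dicts` is used here in place of the record list in the diff):

```python
import polars as pl

records = [
    {"file_path": None, "height": 480, "width": 640, "unused": None},
    {"file_path": None, "height": 720, "width": 1280, "unused": None},
]

# Without the override, an all-None column would be inferred as the Null dtype.
table = pl.from_dicts(records, schema_overrides={"file_path": pl.String})
print(table.schema)   # file_path is String; 'unused' is Null

# Columns whose dtype is Null (here: 'unused') are then dropped.
table = table.drop(pl.selectors.by_dtype(pl.Null))
print(table.columns)  # ['file_path', 'height', 'width']
```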
```diff
@@ -678,12 +724,12 @@ class HafniaDataset:
         """
         dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
         if len(dataset_split_to_be_divided) == 0:
-            split_counts = dict(dataset.samples.select(pl.col(
+            split_counts = dict(dataset.samples.select(pl.col(SampleField.SPLIT).value_counts()).iter_rows())
             raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
         assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{split_name}' split!"
         dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=42)

-        remaining_data = dataset.samples.filter(pl.col(
+        remaining_data = dataset.samples.filter(pl.col(SampleField.SPLIT).is_in([split_name]).not_())
         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
         dataset_new = dataset.update_samples(new_table)
         return dataset_new
```
```diff
@@ -696,15 +742,17 @@ class HafniaDataset:

         # Remove any pre-existing "sample"-tags
         samples = samples.with_columns(
-            pl.col(
+            pl.col(SampleField.TAGS)
+            .list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE))
+            .alias(SampleField.TAGS)
         )

         # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
         samples = samples.with_columns(
             pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
-            .then(pl.col(
-            .otherwise(pl.col(
+            .then(pl.col(SampleField.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(SampleField.TAGS))
         )
         return dataset.update_samples(samples)

```
```diff
@@ -762,6 +810,47 @@ class HafniaDataset:
             dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
         )

+    def drop_task(
+        dataset: "HafniaDataset",
+        task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Drop a task from the dataset.
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        drop_task = dataset.info.get_task_by_name(task_name=task_name)
+        tasks_with_same_primitive = dataset.info.get_tasks_by_primitive(drop_task.primitive)
+
+        no_other_tasks_with_same_primitive = len(tasks_with_same_primitive) == 1
+        if no_other_tasks_with_same_primitive:
+            return dataset.drop_primitive(primitive=drop_task.primitive)
+
+        dataset.info = dataset.info.replace_task(old_task=drop_task, new_task=None)
+        dataset.samples = dataset.samples.with_columns(
+            pl.col(drop_task.primitive.column_name())
+            .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) != drop_task.name)
+            .alias(drop_task.primitive.column_name())
+        )
+
+        return dataset
+
+    def drop_primitive(
+        dataset: "HafniaDataset",
+        primitive: Type[Primitive],
+    ) -> "HafniaDataset":
+        """
+        Drop a primitive from the dataset.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        tasks_to_drop = dataset.info.get_tasks_by_primitive(primitive=primitive)
+        for task in tasks_to_drop:
+            dataset.info = dataset.info.replace_task(old_task=task, new_task=None)
+
+        # Drop the primitive column from the samples table
+        dataset.samples = dataset.samples.drop(primitive.column_name())
+        return dataset
+
     def select_samples_by_class_name(
         dataset: HafniaDataset,
         name: Union[List[str], str],
```
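Both `drop_task` and `drop_primitive` are defined inside `HafniaDataset`, so they can be called as methods and return a new dataset; the original is shallow-copied first. A sketch with a hypothetical task name:

```python
# Sketch only: assumes `dataset` is an existing HafniaDataset; the task name is hypothetical.
from hafnia.dataset.primitives.bbox import Bbox

dataset_without_task = dataset.drop_task(task_name="vehicle-type")
dataset_without_bboxes = dataset.drop_primitive(primitive=Bbox)
```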
```diff
@@ -798,13 +887,63 @@ class HafniaDataset:

         return HafniaDataset(info=merged_info, samples=merged_samples)

-    def
+    def download_files_aws(
+        dataset: HafniaDataset,
+        path_output_folder: Path,
+        aws_credentials: AwsCredentials,
+        force_redownload: bool = False,
+    ) -> HafniaDataset:
+        from hafnia.platform.datasets import fast_copy_files_s3
+
+        remote_src_paths = dataset.samples[SampleField.REMOTE_PATH].unique().to_list()
+        update_rows = []
+        local_dst_paths = []
+        for remote_src_path in remote_src_paths:
+            local_path_str = (path_output_folder / "data" / Path(remote_src_path).name).absolute().as_posix()
+            local_dst_paths.append(local_path_str)
+            update_rows.append(
+                {
+                    SampleField.REMOTE_PATH: remote_src_path,
+                    SampleField.FILE_PATH: local_path_str,
+                }
+            )
+        update_df = pl.DataFrame(update_rows)
+        samples = dataset.samples.update(update_df, on=[SampleField.REMOTE_PATH])
+        dataset = dataset.update_samples(samples)
+
+        if not force_redownload:
+            download_indices = [idx for idx, local_path in enumerate(local_dst_paths) if not Path(local_path).exists()]
+            n_files = len(local_dst_paths)
+            skip_files = n_files - len(download_indices)
+            if skip_files > 0:
+                user_logger.info(
+                    f"Found {skip_files}/{n_files} files already exists. Downloading {len(download_indices)} files."
+                )
+            remote_src_paths = [remote_src_paths[idx] for idx in download_indices]
+            local_dst_paths = [local_dst_paths[idx] for idx in download_indices]
+
+        if len(remote_src_paths) == 0:
+            user_logger.info(
+                "All files already exist locally. Skipping download. Set 'force_redownload=True' to re-download."
+            )
+            return dataset
+
+        environment_vars = aws_credentials.aws_credentials()
+        fast_copy_files_s3(
+            src_paths=remote_src_paths,
+            dst_paths=local_dst_paths,
+            append_envs=environment_vars,
+            description="Downloading images",
+        )
+        return dataset
+
+    def to_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
         """
         Splits the dataset into multiple datasets based on the 'split' column.
         Returns a dictionary with split names as keys and HafniaDataset objects as values.
         """
-        if
-            raise ValueError(f"Dataset must contain a '{
+        if SampleField.SPLIT not in self.samples.columns:
+            raise ValueError(f"Dataset must contain a '{SampleField.SPLIT}' column.")

         splits = {}
         for split_name in SplitName.valid_splits():
```
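`to_dict_dataset_splits` keys the result by the names in `SplitName.valid_splits()`. A short sketch, assuming `dataset` is an existing `HafniaDataset` with a populated `split` column:

```python
# Sketch only: assumes `dataset` is an existing HafniaDataset.
splits = dataset.to_dict_dataset_splits()
for split_name, split_dataset in splits.items():
    print(split_name, len(split_dataset))  # HafniaDataset supports len()
```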
```diff
@@ -813,20 +952,11 @@ class HafniaDataset:
         return splits

     def create_sample_dataset(self) -> "HafniaDataset":
-
-
-            user_logger.warning(
-                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
-                "Please use the 'tags' column with the tag 'sample' instead."
-            )
-            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
-            return self.update_samples(table)
-
-        if ColumnName.TAGS not in self.samples.columns:
-            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+        if SampleField.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{SampleField.TAGS}' column.")

         table = self.samples.filter(
-            pl.col(
+            pl.col(SampleField.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
         )
         return self.update_samples(table)

```
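`create_sample_dataset` now selects rows whose `tags` list contains the sample tag, using a `list.eval` + `list.len` filter. The same polars pattern on a toy frame (column contents and the tag value are illustrative):

```python
import polars as pl

df = pl.DataFrame(
    {
        "sample_index": [0, 1, 2],
        "tags": [["sample"], [], ["other", "sample"]],
    }
)

# Keep rows where the 'tags' list contains at least one element equal to "sample".
is_sample = pl.col("tags").list.eval(pl.element().filter(pl.element() == "sample")).list.len() > 0
print(df.filter(is_sample)["sample_index"].to_list())  # [0, 2]
```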
```diff
@@ -837,10 +967,10 @@ class HafniaDataset:
             split_names = split_name

         for name in split_names:
-            if name not in SplitName.
+            if name not in SplitName.all_split_names():
                 raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

-        filtered_dataset = self.samples.filter(pl.col(
+        filtered_dataset = self.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
         return self.update_samples(filtered_dataset)

     def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
```
```diff
@@ -875,30 +1005,69 @@ class HafniaDataset:
     def copy(self) -> "HafniaDataset":
         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())

+    def create_primitive_table(
+        self,
+        primitive: Type[Primitive],
+        task_name: Optional[str] = None,
+        keep_sample_data: bool = False,
+    ) -> pl.DataFrame:
+        return table_transformations.create_primitive_table(
+            samples_table=self.samples,
+            PrimitiveType=primitive,
+            task_name=task_name,
+            keep_sample_data=keep_sample_data,
+        )
+
     def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
+        path_folder = path_folder.absolute()
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
-
-
-        org_paths =
+        hafnia_dataset = self.copy()  # To avoid inplace modifications
+        new_paths = []
+        org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
         for org_path in track(org_paths, description="- Copy images"):
             new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
             )
-
-
+            new_paths.append(str(new_path))
+        hafnia_dataset.samples = hafnia_dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
+        hafnia_dataset.write_annotations(
+            path_folder=path_folder,
+            drop_null_cols=drop_null_cols,
+            add_version=add_version,
+        )
+
+    def write_annotations(
+        dataset: HafniaDataset,
+        path_folder: Path,
+        drop_null_cols: bool = True,
+        add_version: bool = False,
+    ) -> None:
+        """
+        Writes only the annotations files (JSONL and Parquet) to the specified folder.
+        """
+        user_logger.info(f"Writing dataset annotations to {path_folder}...")
+        path_folder = path_folder.absolute()
+        if not path_folder.exists():
+            path_folder.mkdir(parents=True)
+        dataset.info.write_json(path_folder / FILENAME_DATASET_INFO)

+        samples = dataset.samples
         if drop_null_cols:  # Drops all unused/Null columns
-
+            samples = samples.drop(pl.selectors.by_dtype(pl.Null))
+
+        # Store only relative paths in the annotations files
+        absolute_paths = samples[SampleField.FILE_PATH].to_list()
+        relative_paths = [str(Path(path).relative_to(path_folder)) for path in absolute_paths]
+        samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))

-
-
-        self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+        samples.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
+        samples.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed

         if add_version:
-            path_version = path_folder / "versions" / f"{
+            path_version = path_folder / "versions" / f"{dataset.info.version}"
             path_version.mkdir(parents=True, exist_ok=True)
             for filename in DATASET_FILENAMES_REQUIRED:
                 shutil.copy2(path_folder / filename, path_version / filename)
```
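`write` copies the image files (renamed to their hash values) into the target folder and then delegates the metadata to `write_annotations`, which rewrites `dataset_info`, the JSONL file, and the Parquet file with paths stored relative to that folder. A usage sketch with a hypothetical output folder:

```python
# Sketch only: the output folder is hypothetical; assumes `dataset` is an existing HafniaDataset.
from pathlib import Path

dataset.write(Path("exports/my_dataset"), add_version=False)
```

`write_annotations` can also be called on its own, but only when the samples' `file_path` values already live under the target folder, since it re-derives relative paths with `Path.relative_to(path_folder)`.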
```diff
@@ -956,19 +1125,24 @@ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tu
     format_version_of_dataset = Version(dataset_info.format_version)

     ## Backwards compatibility fixes for older dataset versions
-    if format_version_of_dataset
-
-        samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))
+    if format_version_of_dataset < Version("0.2.0"):
+        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)

     if "file_name" in samples.columns:
-        samples = samples.rename({"file_name":
+        samples = samples.rename({"file_name": SampleField.FILE_PATH})

-    if
-        samples =
+    if SampleField.SAMPLE_INDEX not in samples.columns:
+        samples = table_transformations.add_sample_index(samples)

     # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-    if
+    if SampleField.TAGS not in samples.columns:
         tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
-        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+    if SampleField.STORAGE_FORMAT not in samples.columns:
+        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})

     return samples, dataset_info
```
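The `_dataset_corrections` fixes are gated on a `packaging` version comparison and then backfill missing columns. A standalone sketch of the same checks on a toy table; the `"image"` literal stands in for `StorageFormat.IMAGE`, whose actual value is not visible in this diff:

```python
import polars as pl
from packaging.version import Version

samples = pl.DataFrame({"sample_index": [0, 1]})  # toy table without a 'storage_format' column

if Version("0.1.5") < Version("0.2.0"):
    pass  # pre-0.2.0 compatibility fixes would go here

if "storage_format" not in samples.columns:
    # "image" is a stand-in for StorageFormat.IMAGE.
    samples = samples.with_columns(pl.lit("image").alias("storage_format"))

if samples["sample_index"].dtype != pl.UInt64:
    samples = samples.cast({"sample_index": pl.UInt64})

print(samples.schema)  # sample_index: UInt64, storage_format: String
```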