hafnia 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. hafnia/__init__.py +1 -1
  2. hafnia/dataset/dataset_names.py +128 -15
  3. hafnia/dataset/dataset_recipe/dataset_recipe.py +3 -3
  4. hafnia/dataset/dataset_upload_helper.py +31 -26
  5. hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
  6. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  7. hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
  8. hafnia/dataset/hafnia_dataset.py +246 -72
  9. hafnia/dataset/operations/dataset_stats.py +82 -70
  10. hafnia/dataset/operations/dataset_transformations.py +102 -37
  11. hafnia/dataset/operations/table_transformations.py +132 -15
  12. hafnia/dataset/primitives/bbox.py +3 -5
  13. hafnia/dataset/primitives/bitmask.py +2 -7
  14. hafnia/dataset/primitives/classification.py +3 -3
  15. hafnia/dataset/primitives/polygon.py +2 -4
  16. hafnia/dataset/primitives/primitive.py +1 -1
  17. hafnia/dataset/primitives/segmentation.py +2 -2
  18. hafnia/platform/datasets.py +4 -8
  19. hafnia/platform/download.py +1 -72
  20. hafnia/torch_helpers.py +12 -12
  21. hafnia/utils.py +1 -1
  22. hafnia/visualizations/image_visualizations.py +2 -0
  23. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/METADATA +4 -4
  24. hafnia-0.4.2.dist-info/RECORD +57 -0
  25. hafnia-0.4.2.dist-info/entry_points.txt +2 -0
  26. {cli → hafnia_cli}/__main__.py +2 -2
  27. {cli → hafnia_cli}/config.py +2 -2
  28. {cli → hafnia_cli}/dataset_cmds.py +2 -2
  29. {cli → hafnia_cli}/dataset_recipe_cmds.py +1 -1
  30. {cli → hafnia_cli}/experiment_cmds.py +1 -1
  31. {cli → hafnia_cli}/profile_cmds.py +2 -2
  32. {cli → hafnia_cli}/runc_cmds.py +1 -1
  33. {cli → hafnia_cli}/trainer_package_cmds.py +2 -2
  34. hafnia-0.4.0.dist-info/RECORD +0 -56
  35. hafnia-0.4.0.dist-info/entry_points.txt +0 -2
  36. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/WHEEL +0 -0
  37. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/licenses/LICENSE +0 -0
  38. {cli → hafnia_cli}/__init__.py +0 -0
  39. {cli → hafnia_cli}/consts.py +0 -0
  40. {cli → hafnia_cli}/keychain.py +0 -0
hafnia/dataset/format_conversions/torchvision_datasets.py

@@ -14,8 +14,8 @@ from torchvision.datasets.utils import download_and_extract_archive, extract_arc
 from hafnia import utils
 from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name
 from hafnia.dataset.dataset_names import SplitName
-from hafnia.dataset.format_conversions.image_classification_from_directory import (
-    import_image_classification_directory_tree,
+from hafnia.dataset.format_conversions.format_image_classification_folder import (
+    from_image_classification_folder,
 )
 from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
 from hafnia.dataset.primitives import Classification
@@ -72,7 +72,7 @@ def caltech_101_as_hafnia_dataset(
     path_image_classification_folder = _download_and_extract_caltech_dataset(
         dataset_name, force_redownload=force_redownload
     )
-    hafnia_dataset = import_image_classification_directory_tree(
+    hafnia_dataset = from_image_classification_folder(
        path_image_classification_folder,
        split=SplitName.TRAIN,
        n_samples=n_samples,
@@ -102,7 +102,7 @@ def caltech_256_as_hafnia_dataset(
     path_image_classification_folder = _download_and_extract_caltech_dataset(
         dataset_name, force_redownload=force_redownload
     )
-    hafnia_dataset = import_image_classification_directory_tree(
+    hafnia_dataset = from_image_classification_folder(
        path_image_classification_folder,
        split=SplitName.TRAIN,
        n_samples=n_samples,
@@ -122,6 +122,12 @@ def caltech_256_as_hafnia_dataset(
     }""")
     hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/nyy15-4j048"

+    task = hafnia_dataset.info.get_task_by_primitive(Classification)
+
+    # Class Mapping: To remove numeric prefixes from class names
+    # E.g. "001.ak47 --> ak47", "002.american-flag --> american-flag", ...
+    class_mapping = {name: name.split(".")[-1] for name in task.class_names or []}
+    hafnia_dataset = hafnia_dataset.class_mapper(class_mapping=class_mapping, task_name=task.name)
     return hafnia_dataset


hafnia/dataset/hafnia_dataset.py

@@ -10,6 +10,7 @@ from pathlib import Path
 from random import Random
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

+import cv2
 import more_itertools
 import numpy as np
 import polars as pl
@@ -27,18 +28,21 @@ from hafnia.dataset.dataset_names import (
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
     TAG_IS_SAMPLE,
-    ColumnName,
+    AwsCredentials,
+    PrimitiveField,
+    SampleField,
     SplitName,
+    StorageFormat,
+)
+from hafnia.dataset.format_conversions import (
+    format_image_classification_folder,
+    format_yolo,
 )
 from hafnia.dataset.operations import (
     dataset_stats,
     dataset_transformations,
     table_transformations,
 )
-from hafnia.dataset.operations.table_transformations import (
-    check_image_paths,
-    read_samples_from_path,
-)
 from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
 from hafnia.dataset.primitives.bitmask import Bitmask
@@ -65,6 +69,14 @@ class TaskInfo(BaseModel):
         if self.name is None:
             self.name = self.primitive.default_task_name()

+    def get_class_index(self, class_name: str) -> int:
+        """Get class index for a given class name"""
+        if self.class_names is None:
+            raise ValueError(f"Task '{self.name}' has no class names defined.")
+        if class_name not in self.class_names:
+            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
+        return self.class_names.index(class_name)
+
     # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
     # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
     # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
@@ -100,6 +112,10 @@ class TaskInfo(BaseModel):
            )
        return class_names

+    def full_name(self) -> str:
+        """Get qualified name for the task: <primitive_name>:<task_name>"""
+        return f"{self.primitive.__name__}:{self.name}"
+
     # To get unique hash value for TaskInfo objects
     def __hash__(self) -> int:
         class_names = self.class_names or []
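
Taken together, the two TaskInfo additions above provide a label-to-index lookup and a qualified task identifier. A minimal usage sketch, not part of the diff; the constructor arguments are assumed from the fields visible in these hunks (name, primitive, class_names):

from hafnia.dataset.hafnia_dataset import TaskInfo
from hafnia.dataset.primitives import Classification

# Assumed construction; field names taken from this diff.
task = TaskInfo(name="weather", primitive=Classification, class_names=["clear", "fog", "rain"])

print(task.get_class_index("rain"))  # 2
print(task.full_name())              # "Classification:weather"
# Unknown names raise ValueError instead of returning a sentinel value:
# task.get_class_index("snow")
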
@@ -115,7 +131,6 @@ class DatasetInfo(BaseModel):
     dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
     version: Optional[str] = Field(default=None, description="Version of the dataset")
     tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
-    distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
     reference_bibtex: Optional[str] = Field(
         default=None,
         description="Optional, BibTeX reference to dataset publication",
@@ -244,14 +259,12 @@ class DatasetInfo(BaseModel):
                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
            )
        unique_tasks = set(info0.tasks + info1.tasks)
-        distributions = set((info0.distributions or []) + (info1.distributions or []))
        meta = (info0.meta or {}).copy()
        meta.update(info1.meta or {})
        return DatasetInfo(
            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
            version=None,
            tasks=list(unique_tasks),
-            distributions=list(distributions),
            meta=meta,
            format_version=dataset_format_version,
        )
@@ -267,16 +280,24 @@ class DatasetInfo(BaseModel):
            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
        return tasks_with_name[0]

-    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
         """
-        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
-        have the same primitive type.
+        Get all tasks by their primitive type.
         """
         if isinstance(primitive, str):
             primitive = get_primitive_type_from_string(primitive)

         tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
-        if not tasks_with_primitive:
+        return tasks_with_primitive
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+
+        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
+        if len(tasks_with_primitive) == 0:
             raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
         if len(tasks_with_primitive) > 1:
             raise ValueError(
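
The split into get_tasks_by_primitive (returns a possibly empty list) and get_task_by_primitive (still raises unless exactly one task matches) makes multi-task datasets easier to inspect. A hedged sketch; the helper function below is invented for illustration:

from typing import List

from hafnia.dataset.hafnia_dataset import DatasetInfo
from hafnia.dataset.primitives.bbox import Bbox


def bbox_task_names(info: DatasetInfo) -> List[str]:
    # get_tasks_by_primitive never raises; it simply returns all matching tasks.
    tasks = info.get_tasks_by_primitive(Bbox)
    # get_task_by_primitive(Bbox) would raise ValueError here if len(tasks) != 1.
    return [task.full_name() for task in tasks]
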
@@ -320,7 +341,7 @@ class DatasetInfo(BaseModel):


 class Sample(BaseModel):
-    file_path: str = Field(description="Path to the image file")
+    file_path: Optional[str] = Field(description="Path to the image/video file.")
     height: int = Field(description="Height of the image")
     width: int = Field(description="Width of the image")
     split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
@@ -328,6 +349,10 @@ class Sample(BaseModel):
         default_factory=list,
         description="Tags for a given sample. Used for creating subsets of the dataset.",
     )
+    storage_format: str = Field(
+        default=StorageFormat.IMAGE,
+        description="Storage format. Sample data is stored as image or inside a video or zip file.",
+    )
     collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
     collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
     remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
@@ -338,7 +363,7 @@ class Sample(BaseModel):
     classifications: Optional[List[Classification]] = Field(
         default=None, description="Optional list of classifications"
     )
-    objects: Optional[List[Bbox]] = Field(default=None, description="Optional list of objects (bounding boxes)")
+    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
     bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
     polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")

@@ -374,6 +399,8 @@ class Sample(BaseModel):
         Reads the image from the file path and returns it as a PIL Image.
         Raises FileNotFoundError if the image file does not exist.
         """
+        if self.file_path is None:
+            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
         path_image = Path(self.file_path)
         if not path_image.exists():
             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
@@ -382,8 +409,22 @@ class Sample(BaseModel):
         return image

     def read_image(self) -> np.ndarray:
-        image_pil = self.read_image_pillow()
-        image = np.array(image_pil)
+        if self.storage_format == StorageFormat.VIDEO:
+            video = cv2.VideoCapture(str(self.file_path))
+            if self.collection_index is None:
+                raise ValueError("collection_index must be set for video storage format to read the correct frame.")
+            video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
+            success, image = video.read()
+            video.release()
+            if not success:
+                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
+            return image
+
+        elif self.storage_format == StorageFormat.IMAGE:
+            image_pil = self.read_image_pillow()
+            image = np.array(image_pil)
+        else:
+            raise ValueError(f"Unsupported storage format: {self.storage_format}")
         return image

     def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
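
With the storage_format dispatch above, an image-backed sample still goes through Pillow, while a video-backed sample seeks to collection_index with OpenCV. A sketch of a video-backed sample, not part of the diff; field values are illustrative and it assumes no required Sample fields beyond those visible in these hunks:

from hafnia.dataset.dataset_names import StorageFormat
from hafnia.dataset.hafnia_dataset import Sample

sample = Sample(
    file_path="/data/videos/drive_0001.mp4",  # hypothetical path
    height=1080,
    width=1920,
    split="train",
    storage_format=StorageFormat.VIDEO,
    collection_index=42,           # frame number to decode
    collection_id="drive_0001",    # which video the frame belongs to
)

frame = sample.read_image()  # np.ndarray decoded via cv2.VideoCapture (OpenCV returns BGR)
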
@@ -466,9 +507,11 @@ class HafniaDataset:
     samples: pl.DataFrame

     # Function mapping: Dataset stats
-    split_counts = dataset_stats.split_counts
-    class_counts_for_task = dataset_stats.class_counts_for_task
-    class_counts_all = dataset_stats.class_counts_all
+    calculate_split_counts = dataset_stats.calculate_split_counts
+    calculate_split_counts_extended = dataset_stats.calculate_split_counts_extended
+    calculate_task_class_counts = dataset_stats.calculate_task_class_counts
+    calculate_class_counts = dataset_stats.calculate_class_counts
+    calculate_primitive_counts = dataset_stats.calculate_primitive_counts

     # Function mapping: Print stats
     print_stats = dataset_stats.print_stats
@@ -481,6 +524,13 @@ class HafniaDataset:

     # Function mapping: Dataset transformations
     transform_images = dataset_transformations.transform_images
+    convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
+
+    # Import / export functions
+    from_yolo_format = format_yolo.from_yolo_format
+    to_yolo_format = format_yolo.to_yolo_format
+    to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
+    from_image_classification_folder = format_image_classification_folder.from_image_classification_folder

     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
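
These "function mapping" blocks are plain attribute assignments: free functions whose first parameter is the dataset are bound on the class, so they are callable as methods while living in separate operation modules. A self-contained sketch of the same pattern; all names below are invented for illustration and are not hafnia APIs:

from dataclasses import dataclass, field
from typing import Dict, List


def count_items(dataset: "TinyDataset") -> int:
    """Free function: first argument is the dataset, like dataset_stats.* in hafnia."""
    return len(dataset.items)


def item_counts(dataset: "TinyDataset") -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for item in dataset.items:
        counts[item] = counts.get(item, 0) + 1
    return counts


@dataclass
class TinyDataset:
    items: List[str] = field(default_factory=list)

    # Function mapping: bind the free functions as methods (same trick as HafniaDataset).
    count_items = count_items
    item_counts = item_counts


dataset = TinyDataset(items=["car", "person", "car"])
assert dataset.count_items() == 3           # equivalent to count_items(dataset)
assert dataset.item_counts()["car"] == 2    # equivalent to item_counts(dataset)
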
@@ -501,14 +551,14 @@ class HafniaDataset:
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)

         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-        samples = read_samples_from_path(path_folder)
+        samples = table_transformations.read_samples_from_path(path_folder)
         samples, dataset_info = _dataset_corrections(samples, dataset_info)

         # Convert from relative paths to absolute paths
         dataset_root = path_folder.absolute().as_posix() + "/"
-        samples = samples.with_columns((dataset_root + pl.col(ColumnName.FILE_PATH)).alias(ColumnName.FILE_PATH))
+        samples = samples.with_columns((dataset_root + pl.col(SampleField.FILE_PATH)).alias(SampleField.FILE_PATH))
         if check_for_images:
-            check_image_paths(samples)
+            table_transformations.check_image_paths(samples)
         return HafniaDataset(samples=samples, info=dataset_info)

     @staticmethod
@@ -535,16 +585,12 @@ class HafniaDataset:
            else:
                raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

-        table = pl.from_records(json_samples)
-        table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
-
-        # Add 'dataset_name' to samples
-        table = table.with_columns(
-            pl.when(pl.col(ColumnName.DATASET_NAME).is_null())
-            .then(pl.lit(info.dataset_name))
-            .otherwise(pl.col(ColumnName.DATASET_NAME))
-            .alias(ColumnName.DATASET_NAME)
-        )
+        # To ensure that the 'file_path' column is of type string even if all samples have 'None' as file_path
+        schema_override = {SampleField.FILE_PATH: pl.String}
+        table = pl.from_records(json_samples, schema_overrides=schema_override)
+        table = table.drop(pl.selectors.by_dtype(pl.Null))
+        table = table_transformations.add_sample_index(table)
+        table = table_transformations.add_dataset_name_if_missing(table, dataset_name=info.dataset_name)
         return HafniaDataset(info=info, samples=table)

     @staticmethod
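
The schema override above is what keeps an all-None file_path column typed as String instead of collapsing to the polars Null dtype. A small standalone sketch of the difference (pl.from_dicts is used here for brevity; the diff itself passes the same schema_overrides argument to pl.from_records):

import polars as pl

rows = [
    {"file_path": None, "height": 720},
    {"file_path": None, "height": 1080},
]

# Without an override, an all-None column is inferred as dtype Null:
print(pl.from_dicts(rows).schema["file_path"])  # dtype is Null

# With the override, it stays a (nullable) String column:
print(pl.from_dicts(rows, schema_overrides={"file_path": pl.String}).schema["file_path"])  # dtype is String
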
@@ -678,12 +724,12 @@ class HafniaDataset:
         """
         dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
         if len(dataset_split_to_be_divided) == 0:
-            split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
+            split_counts = dict(dataset.samples.select(pl.col(SampleField.SPLIT).value_counts()).iter_rows())
             raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
         assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{split_name}' split!"
         dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=42)

-        remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
+        remaining_data = dataset.samples.filter(pl.col(SampleField.SPLIT).is_in([split_name]).not_())
         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
         dataset_new = dataset.update_samples(new_table)
         return dataset_new
@@ -696,15 +742,17 @@ class HafniaDataset:

         # Remove any pre-existing "sample"-tags
         samples = samples.with_columns(
-            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+            pl.col(SampleField.TAGS)
+            .list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE))
+            .alias(SampleField.TAGS)
         )

         # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
         samples = samples.with_columns(
             pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
-            .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
-            .otherwise(pl.col(ColumnName.TAGS))
+            .then(pl.col(SampleField.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(SampleField.TAGS))
         )
         return dataset.update_samples(samples)

@@ -762,6 +810,47 @@ class HafniaDataset:
            dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
        )

+    def drop_task(
+        dataset: "HafniaDataset",
+        task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Drop a task from the dataset.
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        drop_task = dataset.info.get_task_by_name(task_name=task_name)
+        tasks_with_same_primitive = dataset.info.get_tasks_by_primitive(drop_task.primitive)
+
+        no_other_tasks_with_same_primitive = len(tasks_with_same_primitive) == 1
+        if no_other_tasks_with_same_primitive:
+            return dataset.drop_primitive(primitive=drop_task.primitive)
+
+        dataset.info = dataset.info.replace_task(old_task=drop_task, new_task=None)
+        dataset.samples = dataset.samples.with_columns(
+            pl.col(drop_task.primitive.column_name())
+            .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) != drop_task.name)
+            .alias(drop_task.primitive.column_name())
+        )
+
+        return dataset
+
+    def drop_primitive(
+        dataset: "HafniaDataset",
+        primitive: Type[Primitive],
+    ) -> "HafniaDataset":
+        """
+        Drop a primitive from the dataset.
+        """
+        dataset = copy.copy(dataset)  # To avoid mutating the original dataset. Shallow copy is sufficient
+        tasks_to_drop = dataset.info.get_tasks_by_primitive(primitive=primitive)
+        for task in tasks_to_drop:
+            dataset.info = dataset.info.replace_task(old_task=task, new_task=None)
+
+        # Drop the primitive column from the samples table
+        dataset.samples = dataset.samples.drop(primitive.column_name())
+        return dataset
+
     def select_samples_by_class_name(
         dataset: HafniaDataset,
         name: Union[List[str], str],
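
drop_task removes a single named task and falls back to drop_primitive when it is the only task using that primitive, in which case the whole annotation column is removed. A hedged usage sketch; the task name below is invented:

from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.dataset.primitives import Classification


def drop_auxiliary_task(dataset: HafniaDataset) -> HafniaDataset:
    # Remove one named task ("weather" is a made-up task name); if it is the only
    # task using its primitive, this is equivalent to dropping the primitive.
    return dataset.drop_task(task_name="weather")


def drop_all_classifications(dataset: HafniaDataset) -> HafniaDataset:
    # Remove every Classification task and the whole classification annotation column.
    return dataset.drop_primitive(primitive=Classification)
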
@@ -798,13 +887,63 @@ class HafniaDataset:

         return HafniaDataset(info=merged_info, samples=merged_samples)

-    def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+    def download_files_aws(
+        dataset: HafniaDataset,
+        path_output_folder: Path,
+        aws_credentials: AwsCredentials,
+        force_redownload: bool = False,
+    ) -> HafniaDataset:
+        from hafnia.platform.datasets import fast_copy_files_s3
+
+        remote_src_paths = dataset.samples[SampleField.REMOTE_PATH].unique().to_list()
+        update_rows = []
+        local_dst_paths = []
+        for remote_src_path in remote_src_paths:
+            local_path_str = (path_output_folder / "data" / Path(remote_src_path).name).absolute().as_posix()
+            local_dst_paths.append(local_path_str)
+            update_rows.append(
+                {
+                    SampleField.REMOTE_PATH: remote_src_path,
+                    SampleField.FILE_PATH: local_path_str,
+                }
+            )
+        update_df = pl.DataFrame(update_rows)
+        samples = dataset.samples.update(update_df, on=[SampleField.REMOTE_PATH])
+        dataset = dataset.update_samples(samples)
+
+        if not force_redownload:
+            download_indices = [idx for idx, local_path in enumerate(local_dst_paths) if not Path(local_path).exists()]
+            n_files = len(local_dst_paths)
+            skip_files = n_files - len(download_indices)
+            if skip_files > 0:
+                user_logger.info(
+                    f"Found {skip_files}/{n_files} files already exists. Downloading {len(download_indices)} files."
+                )
+            remote_src_paths = [remote_src_paths[idx] for idx in download_indices]
+            local_dst_paths = [local_dst_paths[idx] for idx in download_indices]
+
+        if len(remote_src_paths) == 0:
+            user_logger.info(
+                "All files already exist locally. Skipping download. Set 'force_redownload=True' to re-download."
+            )
+            return dataset
+
+        environment_vars = aws_credentials.aws_credentials()
+        fast_copy_files_s3(
+            src_paths=remote_src_paths,
+            dst_paths=local_dst_paths,
+            append_envs=environment_vars,
+            description="Downloading images",
+        )
+        return dataset
+
+    def to_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
         """
         Splits the dataset into multiple datasets based on the 'split' column.
         Returns a dictionary with split names as keys and HafniaDataset objects as values.
         """
-        if ColumnName.SPLIT not in self.samples.columns:
-            raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
+        if SampleField.SPLIT not in self.samples.columns:
+            raise ValueError(f"Dataset must contain a '{SampleField.SPLIT}' column.")

         splits = {}
         for split_name in SplitName.valid_splits():
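
Besides the new AWS download helper, the rename from as_dict_dataset_splits to to_dict_dataset_splits lines up with the other to_*/from_* conversion names. A hedged sketch of the renamed accessor:

from hafnia.dataset.hafnia_dataset import HafniaDataset


def print_split_sizes(dataset: HafniaDataset) -> None:
    # Keys are split names (e.g. "train"/"val"/"test"); values are HafniaDataset objects.
    for split_name, split_dataset in dataset.to_dict_dataset_splits().items():
        print(f"{split_name}: {len(split_dataset)} samples")
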
@@ -813,20 +952,11 @@ class HafniaDataset:
         return splits

     def create_sample_dataset(self) -> "HafniaDataset":
-        # Backwards compatibility. Remove in future versions when dataset have been updated
-        if "is_sample" in self.samples.columns:
-            user_logger.warning(
-                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
-                "Please use the 'tags' column with the tag 'sample' instead."
-            )
-            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
-            return self.update_samples(table)
-
-        if ColumnName.TAGS not in self.samples.columns:
-            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+        if SampleField.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{SampleField.TAGS}' column.")

         table = self.samples.filter(
-            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+            pl.col(SampleField.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
         )
         return self.update_samples(table)

@@ -837,10 +967,10 @@ class HafniaDataset:
            split_names = split_name

        for name in split_names:
-            if name not in SplitName.valid_splits():
+            if name not in SplitName.all_split_names():
                raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

-        filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+        filtered_dataset = self.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
        return self.update_samples(filtered_dataset)

     def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
@@ -875,30 +1005,69 @@ class HafniaDataset:
     def copy(self) -> "HafniaDataset":
         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())

+    def create_primitive_table(
+        self,
+        primitive: Type[Primitive],
+        task_name: Optional[str] = None,
+        keep_sample_data: bool = False,
+    ) -> pl.DataFrame:
+        return table_transformations.create_primitive_table(
+            samples_table=self.samples,
+            PrimitiveType=primitive,
+            task_name=task_name,
+            keep_sample_data=keep_sample_data,
+        )
+
     def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
+        path_folder = path_folder.absolute()
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
-
-        new_relative_paths = []
-        org_paths = self.samples[ColumnName.FILE_PATH].to_list()
+        hafnia_dataset = self.copy()  # To avoid inplace modifications
+        new_paths = []
+        org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
         for org_path in track(org_paths, description="- Copy images"):
             new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
             )
-            new_relative_paths.append(str(new_path.relative_to(path_folder)))
-        table = self.samples.with_columns(pl.Series(new_relative_paths).alias(ColumnName.FILE_PATH))
+            new_paths.append(str(new_path))
+        hafnia_dataset.samples = hafnia_dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
+        hafnia_dataset.write_annotations(
+            path_folder=path_folder,
+            drop_null_cols=drop_null_cols,
+            add_version=add_version,
+        )
+
+    def write_annotations(
+        dataset: HafniaDataset,
+        path_folder: Path,
+        drop_null_cols: bool = True,
+        add_version: bool = False,
+    ) -> None:
+        """
+        Writes only the annotations files (JSONL and Parquet) to the specified folder.
+        """
+        user_logger.info(f"Writing dataset annotations to {path_folder}...")
+        path_folder = path_folder.absolute()
+        if not path_folder.exists():
+            path_folder.mkdir(parents=True)
+        dataset.info.write_json(path_folder / FILENAME_DATASET_INFO)

+        samples = dataset.samples
         if drop_null_cols:  # Drops all unused/Null columns
-            table = table.drop(pl.selectors.by_dtype(pl.Null))
+            samples = samples.drop(pl.selectors.by_dtype(pl.Null))
+
+        # Store only relative paths in the annotations files
+        absolute_paths = samples[SampleField.FILE_PATH].to_list()
+        relative_paths = [str(Path(path).relative_to(path_folder)) for path in absolute_paths]
+        samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))

-        table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
-        table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
-        self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+        samples.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
+        samples.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed

         if add_version:
-            path_version = path_folder / "versions" / f"{self.info.version}"
+            path_version = path_folder / "versions" / f"{dataset.info.version}"
             path_version.mkdir(parents=True, exist_ok=True)
             for filename in DATASET_FILENAMES_REQUIRED:
                 shutil.copy2(path_folder / filename, path_version / filename)
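
write now copies the image files and then delegates to the new write_annotations, which only rewrites the dataset info plus the JSONL/Parquet annotation files with paths stored relative to the target folder. A hedged sketch of when to use each:

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset


def export_full(dataset: HafniaDataset, path_out: Path) -> None:
    # Copies images (renamed to hash values) into path_out and writes all annotation files.
    dataset.write(path_out, add_version=False)


def refresh_annotations_only(dataset: HafniaDataset, path_out: Path) -> None:
    # Rewrites only the JSONL/Parquet annotation files and dataset info; assumes the sample
    # file paths already point inside path_out, since they are stored relative to that folder.
    dataset.write_annotations(path_out, drop_null_cols=True)
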
@@ -956,19 +1125,24 @@ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tu
     format_version_of_dataset = Version(dataset_info.format_version)

     ## Backwards compatibility fixes for older dataset versions
-    if format_version_of_dataset <= Version("0.3.0"):
-        if ColumnName.DATASET_NAME not in samples.columns:
-            samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))
+    if format_version_of_dataset < Version("0.2.0"):
+        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)

     if "file_name" in samples.columns:
-        samples = samples.rename({"file_name": ColumnName.FILE_PATH})
+        samples = samples.rename({"file_name": SampleField.FILE_PATH})

-    if ColumnName.SAMPLE_INDEX not in samples.columns:
-        samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
+    if SampleField.SAMPLE_INDEX not in samples.columns:
+        samples = table_transformations.add_sample_index(samples)

     # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-    if ColumnName.TAGS not in samples.columns:
+    if SampleField.TAGS not in samples.columns:
         tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
-        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+    if SampleField.STORAGE_FORMAT not in samples.columns:
+        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})

     return samples, dataset_info