hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +130 -16
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +83 -22
  11. hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
  12. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  13. hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
  14. hafnia/dataset/hafnia_dataset.py +396 -96
  15. hafnia/dataset/operations/dataset_stats.py +84 -73
  16. hafnia/dataset/operations/dataset_transformations.py +116 -47
  17. hafnia/dataset/operations/table_transformations.py +135 -17
  18. hafnia/dataset/primitives/bbox.py +25 -14
  19. hafnia/dataset/primitives/bitmask.py +22 -15
  20. hafnia/dataset/primitives/classification.py +16 -8
  21. hafnia/dataset/primitives/point.py +7 -3
  22. hafnia/dataset/primitives/polygon.py +15 -10
  23. hafnia/dataset/primitives/primitive.py +1 -1
  24. hafnia/dataset/primitives/segmentation.py +12 -9
  25. hafnia/experiment/hafnia_logger.py +0 -9
  26. hafnia/platform/dataset_recipe.py +7 -2
  27. hafnia/platform/datasets.py +5 -9
  28. hafnia/platform/download.py +24 -90
  29. hafnia/torch_helpers.py +12 -12
  30. hafnia/utils.py +17 -0
  31. hafnia/visualizations/image_visualizations.py +3 -1
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
  33. hafnia-0.4.1.dist-info/RECORD +57 -0
  34. hafnia-0.3.0.dist-info/RECORD +0 -53
  35. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
  36. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -8,14 +8,16 @@ from dataclasses import dataclass
  from datetime import datetime
  from pathlib import Path
  from random import Random
- from typing import Any, Dict, List, Optional, Type, Union
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union

+ import cv2
  import more_itertools
  import numpy as np
  import polars as pl
+ from packaging.version import Version
  from PIL import Image
  from pydantic import BaseModel, Field, field_serializer, field_validator
- from tqdm import tqdm
+ from rich.progress import track

  import hafnia
  from hafnia.dataset import dataset_helpers
@@ -26,13 +28,20 @@ from hafnia.dataset.dataset_names import (
      FILENAME_DATASET_INFO,
      FILENAME_RECIPE_JSON,
      TAG_IS_SAMPLE,
-     ColumnName,
+     AwsCredentials,
+     PrimitiveField,
+     SampleField,
      SplitName,
+     StorageFormat,
  )
- from hafnia.dataset.operations import dataset_stats, dataset_transformations, table_transformations
- from hafnia.dataset.operations.table_transformations import (
-     check_image_paths,
-     read_table_from_path,
+ from hafnia.dataset.format_conversions import (
+     format_image_classification_folder,
+     format_yolo,
+ )
+ from hafnia.dataset.operations import (
+     dataset_stats,
+     dataset_transformations,
+     table_transformations,
  )
  from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
  from hafnia.dataset.primitives.bbox import Bbox
@@ -44,14 +53,30 @@ from hafnia.log import user_logger


  class TaskInfo(BaseModel):
-     primitive: Type[Primitive] # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
-     class_names: Optional[List[str]] # Class names for the tasks. To get consistent class indices specify class_names.
-     name: Optional[str] = None # Use 'None' to use default name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
+     primitive: Type[Primitive] = Field(
+         description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+     )
+     class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+     name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Optional name for the task. 'None' will use default name of the provided primitive. "
+             "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+         ),
+     )

      def model_post_init(self, __context: Any) -> None:
          if self.name is None:
              self.name = self.primitive.default_task_name()

+     def get_class_index(self, class_name: str) -> int:
+         """Get class index for a given class name"""
+         if self.class_names is None:
+             raise ValueError(f"Task '{self.name}' has no class names defined.")
+         if class_name not in self.class_names:
+             raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
+         return self.class_names.index(class_name)
+
      # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
      # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
      # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
@@ -87,6 +112,10 @@ class TaskInfo(BaseModel):
              )
          return class_names

+     def full_name(self) -> str:
+         """Get qualified name for the task: <primitive_name>:<task_name>"""
+         return f"{self.primitive.__name__}:{self.name}"
+
      # To get unique hash value for TaskInfo objects
      def __hash__(self) -> int:
          class_names = self.class_names or []
@@ -99,17 +128,36 @@ class TaskInfo(BaseModel):


  class DatasetInfo(BaseModel):
-     dataset_name: str
-     version: str # Dataset version. This is not the same as the Hafnia dataset format version.
-     tasks: List[TaskInfo]
-     distributions: Optional[List[TaskInfo]] = None # Distributions. TODO: FIX/REMOVE/CHANGE this
-     meta: Optional[Dict[str, Any]] = None # Metadata about the dataset, e.g. description, etc.
-     format_version: str = hafnia.__dataset_format_version__ # Version of the Hafnia dataset format
-     updated_at: datetime = datetime.now()
+     dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+     version: Optional[str] = Field(default=None, description="Version of the dataset")
+     tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+     reference_bibtex: Optional[str] = Field(
+         default=None,
+         description="Optional, BibTeX reference to dataset publication",
+     )
+     reference_paper_url: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to dataset publication",
+     )
+     reference_dataset_page: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to the dataset page",
+     )
+     meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+     format_version: str = Field(
+         default=hafnia.__dataset_format_version__,
+         description="Version of the Hafnia dataset format. You should not set this manually.",
+     )
+     updated_at: datetime = Field(
+         default_factory=datetime.now,
+         description="Timestamp of the last update to the dataset info. You should not set this manually.",
+     )

      @field_validator("tasks", mode="after")
      @classmethod
-     def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+     def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+         if tasks is None:
+             return []
          task_name_counts = collections.Counter(task.name for task in tasks)
          duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
          if duplicate_task_names:
@@ -118,6 +166,35 @@ class DatasetInfo(BaseModel):
              )
          return tasks

+     @field_validator("format_version")
+     @classmethod
+     def _validate_format_version(cls, format_version: str) -> str:
+         try:
+             Version(format_version)
+         except Exception as e:
+             raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+         if Version(format_version) > Version(hafnia.__dataset_format_version__):
+             user_logger.warning(
+                 f"The loaded dataset format version '{format_version}' is newer than the format version "
+                 f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                 f"updating Hafnia package."
+             )
+         return format_version
+
+     @field_validator("version")
+     @classmethod
+     def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+         if dataset_version is None:
+             return None
+
+         try:
+             Version(dataset_version)
+         except Exception as e:
+             raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+         return dataset_version
+
      def check_for_duplicate_task_names(self) -> List[TaskInfo]:
          return self._validate_check_for_duplicate_tasks(self.tasks)

@@ -182,14 +259,12 @@ class DatasetInfo(BaseModel):
                  f"Hafnia format version '{hafnia.__dataset_format_version__}'."
              )
          unique_tasks = set(info0.tasks + info1.tasks)
-         distributions = set((info0.distributions or []) + (info1.distributions or []))
          meta = (info0.meta or {}).copy()
          meta.update(info1.meta or {})
          return DatasetInfo(
              dataset_name=info0.dataset_name + "+" + info1.dataset_name,
-             version="merged",
+             version=None,
              tasks=list(unique_tasks),
-             distributions=list(distributions),
              meta=meta,
              format_version=dataset_format_version,
          )
@@ -205,16 +280,24 @@ class DatasetInfo(BaseModel):
              raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
          return tasks_with_name[0]

-     def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+     def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
          """
-         Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
-         have the same primitive type.
+         Get all tasks by their primitive type.
          """
          if isinstance(primitive, str):
              primitive = get_primitive_type_from_string(primitive)

          tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
-         if not tasks_with_primitive:
+         return tasks_with_primitive
+
+     def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+         """
+         Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+         have the same primitive type.
+         """
+
+         tasks_with_primitive = self.get_tasks_by_primitive(primitive)
+         if len(tasks_with_primitive) == 0:
              raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
          if len(tasks_with_primitive) > 1:
              raise ValueError(
@@ -258,22 +341,44 @@ class DatasetInfo(BaseModel):


  class Sample(BaseModel):
-     file_name: str
-     height: int
-     width: int
-     split: str # Split name, e.g., "train", "val", "test"
-     tags: List[str] = [] # tags for a given sample. Used for creating subsets of the dataset.
-     collection_index: Optional[int] = None # Optional e.g. frame number for video datasets
-     collection_id: Optional[str] = None # Optional e.g. video name for video datasets
-     remote_path: Optional[str] = None # Optional remote path for the image, if applicable
-     sample_index: Optional[int] = None # Don't manually set this, it is used for indexing samples in the dataset.
-     classifications: Optional[List[Classification]] = None # Optional classification primitive
-     objects: Optional[List[Bbox]] = None # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
-     bitmasks: Optional[List[Bitmask]] = None # List of bitmasks, if applicable
-     polygons: Optional[List[Polygon]] = None # List of polygons, if applicable
-
-     attribution: Optional[Attribution] = None # Attribution information for the image
-     meta: Optional[Dict] = None # Additional metadata, e.g., camera settings, GPS data, etc.
+     file_path: Optional[str] = Field(description="Path to the image/video file.")
+     height: int = Field(description="Height of the image")
+     width: int = Field(description="Width of the image")
+     split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+     tags: List[str] = Field(
+         default_factory=list,
+         description="Tags for a given sample. Used for creating subsets of the dataset.",
+     )
+     storage_format: str = Field(
+         default=StorageFormat.IMAGE,
+         description="Storage format. Sample data is stored as image or inside a video or zip file.",
+     )
+     collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+     collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+     remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+     sample_index: Optional[int] = Field(
+         default=None,
+         description="Don't manually set this, it is used for indexing samples in the dataset.",
+     )
+     classifications: Optional[List[Classification]] = Field(
+         default=None, description="Optional list of classifications"
+     )
+     bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
+     bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+     polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+     attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+     dataset_name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Don't manually set this, it will be automatically defined during initialization. "
+             "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+         ),
+     )
+     meta: Optional[Dict] = Field(
+         default=None,
+         description="Additional metadata, e.g., camera settings, GPS data, etc.",
+     )

      def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
          """
@@ -294,7 +399,9 @@ class Sample(BaseModel):
          Reads the image from the file path and returns it as a PIL Image.
          Raises FileNotFoundError if the image file does not exist.
          """
-         path_image = Path(self.file_name)
+         if self.file_path is None:
+             raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
+         path_image = Path(self.file_path)
          if not path_image.exists():
              raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

@@ -302,8 +409,22 @@ class Sample(BaseModel):
          return image

      def read_image(self) -> np.ndarray:
-         image_pil = self.read_image_pillow()
-         image = np.array(image_pil)
+         if self.storage_format == StorageFormat.VIDEO:
+             video = cv2.VideoCapture(str(self.file_path))
+             if self.collection_index is None:
+                 raise ValueError("collection_index must be set for video storage format to read the correct frame.")
+             video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
+             success, image = video.read()
+             video.release()
+             if not success:
+                 raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
+             return image
+
+         elif self.storage_format == StorageFormat.IMAGE:
+             image_pil = self.read_image_pillow()
+             image = np.array(image_pil)
+         else:
+             raise ValueError(f"Unsupported storage format: {self.storage_format}")
          return image

      def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
@@ -386,9 +507,11 @@ class HafniaDataset:
      samples: pl.DataFrame

      # Function mapping: Dataset stats
-     split_counts = dataset_stats.split_counts
-     class_counts_for_task = dataset_stats.class_counts_for_task
-     class_counts_all = dataset_stats.class_counts_all
+     calculate_split_counts = dataset_stats.calculate_split_counts
+     calculate_split_counts_extended = dataset_stats.calculate_split_counts_extended
+     calculate_task_class_counts = dataset_stats.calculate_task_class_counts
+     calculate_class_counts = dataset_stats.calculate_class_counts
+     calculate_primitive_counts = dataset_stats.calculate_primitive_counts

      # Function mapping: Print stats
      print_stats = dataset_stats.print_stats
@@ -401,6 +524,13 @@ class HafniaDataset:

      # Function mapping: Dataset transformations
      transform_images = dataset_transformations.transform_images
+     convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
+
+     # Import / export functions
+     from_yolo_format = format_yolo.from_yolo_format
+     to_yolo_format = format_yolo.to_yolo_format
+     to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
+     from_image_classification_folder = format_image_classification_folder.from_image_classification_folder

      def __getitem__(self, item: int) -> Dict[str, Any]:
          return self.samples.row(index=item, named=True)
@@ -413,30 +543,23 @@ class HafniaDataset:
              yield row

      def __post_init__(self):
-         samples = self.samples
-         if ColumnName.SAMPLE_INDEX not in samples.columns:
-             samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
-
-         # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-         if ColumnName.TAGS not in samples.columns:
-             tags_column: List[List[str]] = [[] for _ in range(len(self))] # type: ignore[annotation-unchecked]
-             samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
-
-         self.samples = samples
+         self.samples, self.info = _dataset_corrections(self.samples, self.info)

      @staticmethod
      def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+         path_folder = Path(path_folder)
          HafniaDataset.check_dataset_path(path_folder, raise_error=True)

          dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-         table = read_table_from_path(path_folder)
+         samples = table_transformations.read_samples_from_path(path_folder)
+         samples, dataset_info = _dataset_corrections(samples, dataset_info)

          # Convert from relative paths to absolute paths
          dataset_root = path_folder.absolute().as_posix() + "/"
-         table = table.with_columns((dataset_root + pl.col("file_name")).alias("file_name"))
+         samples = samples.with_columns((dataset_root + pl.col(SampleField.FILE_PATH)).alias(SampleField.FILE_PATH))
          if check_for_images:
-             check_image_paths(table)
-         return HafniaDataset(samples=table, info=dataset_info)
+             table_transformations.check_image_paths(samples)
+         return HafniaDataset(samples=samples, info=dataset_info)

      @staticmethod
      def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
@@ -462,8 +585,12 @@ class HafniaDataset:
              else:
                  raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

-         table = pl.from_records(json_samples)
-         table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+         # To ensure that the 'file_path' column is of type string even if all samples have 'None' as file_path
+         schema_override = {SampleField.FILE_PATH: pl.String}
+         table = pl.from_records(json_samples, schema_overrides=schema_override)
+         table = table.drop(pl.selectors.by_dtype(pl.Null))
+         table = table_transformations.add_sample_index(table)
+         table = table_transformations.add_dataset_name_if_missing(table, dataset_name=info.dataset_name)
          return HafniaDataset(info=info, samples=table)

      @staticmethod
@@ -518,6 +645,28 @@ class HafniaDataset:
              merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
          return merged_dataset

+     @staticmethod
+     def from_name_public_dataset(
+         name: str,
+         force_redownload: bool = False,
+         n_samples: Optional[int] = None,
+     ) -> HafniaDataset:
+         from hafnia.dataset.format_conversions.torchvision_datasets import (
+             torchvision_to_hafnia_converters,
+         )
+
+         name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+         if name not in name_to_torchvision_function:
+             raise ValueError(
+                 f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+             )
+         vision_dataset = name_to_torchvision_function[name]
+         return vision_dataset(
+             force_redownload=force_redownload,
+             n_samples=n_samples,
+         )
+
      def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
          table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
          return dataset.update_samples(table)
@@ -575,12 +724,12 @@ class HafniaDataset:
          """
          dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
          if len(dataset_split_to_be_divided) == 0:
-             split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
+             split_counts = dict(dataset.samples.select(pl.col(SampleField.SPLIT).value_counts()).iter_rows())
              raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
          assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{split_name}' split!"
          dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=42)

-         remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
+         remaining_data = dataset.samples.filter(pl.col(SampleField.SPLIT).is_in([split_name]).not_())
          new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
          dataset_new = dataset.update_samples(new_table)
          return dataset_new
@@ -593,21 +742,23 @@ class HafniaDataset:

          # Remove any pre-existing "sample"-tags
          samples = samples.with_columns(
-             pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+             pl.col(SampleField.TAGS)
+             .list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE))
+             .alias(SampleField.TAGS)
          )

          # Add "sample" to tags column for the selected samples
          is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
          samples = samples.with_columns(
              pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
-             .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
-             .otherwise(pl.col(ColumnName.TAGS))
+             .then(pl.col(SampleField.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+             .otherwise(pl.col(SampleField.TAGS))
          )
          return dataset.update_samples(samples)

      def class_mapper(
          dataset: "HafniaDataset",
-         class_mapping: Dict[str, str],
+         class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
          method: str = "strict",
          primitive: Optional[Type[Primitive]] = None,
          task_name: Optional[str] = None,
@@ -659,6 +810,47 @@ class HafniaDataset:
              dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
          )

+     def drop_task(
+         dataset: "HafniaDataset",
+         task_name: str,
+     ) -> "HafniaDataset":
+         """
+         Drop a task from the dataset.
+         If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+         """
+         dataset = copy.copy(dataset) # To avoid mutating the original dataset. Shallow copy is sufficient
+         drop_task = dataset.info.get_task_by_name(task_name=task_name)
+         tasks_with_same_primitive = dataset.info.get_tasks_by_primitive(drop_task.primitive)
+
+         no_other_tasks_with_same_primitive = len(tasks_with_same_primitive) == 1
+         if no_other_tasks_with_same_primitive:
+             return dataset.drop_primitive(primitive=drop_task.primitive)
+
+         dataset.info = dataset.info.replace_task(old_task=drop_task, new_task=None)
+         dataset.samples = dataset.samples.with_columns(
+             pl.col(drop_task.primitive.column_name())
+             .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) != drop_task.name)
+             .alias(drop_task.primitive.column_name())
+         )
+
+         return dataset
+
+     def drop_primitive(
+         dataset: "HafniaDataset",
+         primitive: Type[Primitive],
+     ) -> "HafniaDataset":
+         """
+         Drop a primitive from the dataset.
+         """
+         dataset = copy.copy(dataset) # To avoid mutating the original dataset. Shallow copy is sufficient
+         tasks_to_drop = dataset.info.get_tasks_by_primitive(primitive=primitive)
+         for task in tasks_to_drop:
+             dataset.info = dataset.info.replace_task(old_task=task, new_task=None)
+
+         # Drop the primitive column from the samples table
+         dataset.samples = dataset.samples.drop(primitive.column_name())
+         return dataset
+
      def select_samples_by_class_name(
          dataset: HafniaDataset,
          name: Union[List[str], str],
@@ -695,13 +887,63 @@ class HafniaDataset:

          return HafniaDataset(info=merged_info, samples=merged_samples)

-     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+     def download_files_aws(
+         dataset: HafniaDataset,
+         path_output_folder: Path,
+         aws_credentials: AwsCredentials,
+         force_redownload: bool = False,
+     ) -> HafniaDataset:
+         from hafnia.platform.datasets import fast_copy_files_s3
+
+         remote_src_paths = dataset.samples[SampleField.REMOTE_PATH].unique().to_list()
+         update_rows = []
+         local_dst_paths = []
+         for remote_src_path in remote_src_paths:
+             local_path_str = (path_output_folder / "data" / Path(remote_src_path).name).absolute().as_posix()
+             local_dst_paths.append(local_path_str)
+             update_rows.append(
+                 {
+                     SampleField.REMOTE_PATH: remote_src_path,
+                     SampleField.FILE_PATH: local_path_str,
+                 }
+             )
+         update_df = pl.DataFrame(update_rows)
+         samples = dataset.samples.update(update_df, on=[SampleField.REMOTE_PATH])
+         dataset = dataset.update_samples(samples)
+
+         if not force_redownload:
+             download_indices = [idx for idx, local_path in enumerate(local_dst_paths) if not Path(local_path).exists()]
+             n_files = len(local_dst_paths)
+             skip_files = n_files - len(download_indices)
+             if skip_files > 0:
+                 user_logger.info(
+                     f"Found {skip_files}/{n_files} files already exists. Downloading {len(download_indices)} files."
+                 )
+             remote_src_paths = [remote_src_paths[idx] for idx in download_indices]
+             local_dst_paths = [local_dst_paths[idx] for idx in download_indices]
+
+         if len(remote_src_paths) == 0:
+             user_logger.info(
+                 "All files already exist locally. Skipping download. Set 'force_redownload=True' to re-download."
+             )
+             return dataset
+
+         environment_vars = aws_credentials.aws_credentials()
+         fast_copy_files_s3(
+             src_paths=remote_src_paths,
+             dst_paths=local_dst_paths,
+             append_envs=environment_vars,
+             description="Downloading images",
+         )
+         return dataset
+
+     def to_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
          """
          Splits the dataset into multiple datasets based on the 'split' column.
          Returns a dictionary with split names as keys and HafniaDataset objects as values.
          """
-         if ColumnName.SPLIT not in self.samples.columns:
-             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
+         if SampleField.SPLIT not in self.samples.columns:
+             raise ValueError(f"Dataset must contain a '{SampleField.SPLIT}' column.")

          splits = {}
          for split_name in SplitName.valid_splits():
@@ -710,20 +952,11 @@ class HafniaDataset:
          return splits

      def create_sample_dataset(self) -> "HafniaDataset":
-         # Backwards compatibility. Remove in future versions when dataset have been updated
-         if "is_sample" in self.samples.columns:
-             user_logger.warning(
-                 "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
-                 "Please use the 'tags' column with the tag 'sample' instead."
-             )
-             table = self.samples.filter(pl.col("is_sample") == True) # noqa: E712
-             return self.update_samples(table)
-
-         if ColumnName.TAGS not in self.samples.columns:
-             raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+         if SampleField.TAGS not in self.samples.columns:
+             raise ValueError(f"Dataset must contain an '{SampleField.TAGS}' column.")

          table = self.samples.filter(
-             pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+             pl.col(SampleField.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
          )
          return self.update_samples(table)

@@ -734,10 +967,10 @@ class HafniaDataset:
              split_names = split_name

          for name in split_names:
-             if name not in SplitName.valid_splits():
+             if name not in SplitName.all_split_names():
                  raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

-         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+         filtered_dataset = self.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
          return self.update_samples(filtered_dataset)

      def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
@@ -772,29 +1005,69 @@ class HafniaDataset:
      def copy(self) -> "HafniaDataset":
          return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())

+     def create_primitive_table(
+         self,
+         primitive: Type[Primitive],
+         task_name: Optional[str] = None,
+         keep_sample_data: bool = False,
+     ) -> pl.DataFrame:
+         return table_transformations.create_primitive_table(
+             samples_table=self.samples,
+             PrimitiveType=primitive,
+             task_name=task_name,
+             keep_sample_data=keep_sample_data,
+         )
+
      def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
          user_logger.info(f"Writing dataset to {path_folder}...")
+         path_folder = path_folder.absolute()
          if not path_folder.exists():
              path_folder.mkdir(parents=True)
-
-         new_relative_paths = []
-         for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
+         hafnia_dataset = self.copy() # To avoid inplace modifications
+         new_paths = []
+         org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
+         for org_path in track(org_paths, description="- Copy images"):
              new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                  path_source=Path(org_path),
                  path_dataset_root=path_folder,
              )
-             new_relative_paths.append(str(new_path.relative_to(path_folder)))
-         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+             new_paths.append(str(new_path))
+         hafnia_dataset.samples = hafnia_dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
+         hafnia_dataset.write_annotations(
+             path_folder=path_folder,
+             drop_null_cols=drop_null_cols,
+             add_version=add_version,
+         )

+     def write_annotations(
+         dataset: HafniaDataset,
+         path_folder: Path,
+         drop_null_cols: bool = True,
+         add_version: bool = False,
+     ) -> None:
+         """
+         Writes only the annotations files (JSONL and Parquet) to the specified folder.
+         """
+         user_logger.info(f"Writing dataset annotations to {path_folder}...")
+         path_folder = path_folder.absolute()
+         if not path_folder.exists():
+             path_folder.mkdir(parents=True)
+         dataset.info.write_json(path_folder / FILENAME_DATASET_INFO)
+
+         samples = dataset.samples
          if drop_null_cols: # Drops all unused/Null columns
-             table = table.drop(pl.selectors.by_dtype(pl.Null))
+             samples = samples.drop(pl.selectors.by_dtype(pl.Null))
+
+         # Store only relative paths in the annotations files
+         absolute_paths = samples[SampleField.FILE_PATH].to_list()
+         relative_paths = [str(Path(path).relative_to(path_folder)) for path in absolute_paths]
+         samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))

-         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL) # Json for readability
-         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET) # Parquet for speed
-         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+         samples.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL) # Json for readability
+         samples.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET) # Parquet for speed

          if add_version:
-             path_version = path_folder / "versions" / f"{self.info.version}"
+             path_version = path_folder / "versions" / f"{dataset.info.version}"
              path_version.mkdir(parents=True, exist_ok=True)
              for filename in DATASET_FILENAMES_REQUIRED:
                  shutil.copy2(path_folder / filename, path_version / filename)
@@ -846,3 +1119,30 @@ def get_or_create_dataset_path_from_recipe(
      dataset.write(path_dataset)

      return path_dataset
+
+
+ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+     format_version_of_dataset = Version(dataset_info.format_version)
+
+     ## Backwards compatibility fixes for older dataset versions
+     if format_version_of_dataset < Version("0.2.0"):
+         samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
+
+     if "file_name" in samples.columns:
+         samples = samples.rename({"file_name": SampleField.FILE_PATH})
+
+     if SampleField.SAMPLE_INDEX not in samples.columns:
+         samples = table_transformations.add_sample_index(samples)
+
+     # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+     if SampleField.TAGS not in samples.columns:
+         tags_column: List[List[str]] = [[] for _ in range(len(samples))] # type: ignore[annotation-unchecked]
+         samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+     if SampleField.STORAGE_FORMAT not in samples.columns:
+         samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+     if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+         samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
+
+     return samples, dataset_info