hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +130 -16
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +83 -22
  11. hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
  12. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  13. hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
  14. hafnia/dataset/hafnia_dataset.py +396 -96
  15. hafnia/dataset/operations/dataset_stats.py +84 -73
  16. hafnia/dataset/operations/dataset_transformations.py +116 -47
  17. hafnia/dataset/operations/table_transformations.py +135 -17
  18. hafnia/dataset/primitives/bbox.py +25 -14
  19. hafnia/dataset/primitives/bitmask.py +22 -15
  20. hafnia/dataset/primitives/classification.py +16 -8
  21. hafnia/dataset/primitives/point.py +7 -3
  22. hafnia/dataset/primitives/polygon.py +15 -10
  23. hafnia/dataset/primitives/primitive.py +1 -1
  24. hafnia/dataset/primitives/segmentation.py +12 -9
  25. hafnia/experiment/hafnia_logger.py +0 -9
  26. hafnia/platform/dataset_recipe.py +7 -2
  27. hafnia/platform/datasets.py +5 -9
  28. hafnia/platform/download.py +24 -90
  29. hafnia/torch_helpers.py +12 -12
  30. hafnia/utils.py +17 -0
  31. hafnia/visualizations/image_visualizations.py +3 -1
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
  33. hafnia-0.4.1.dist-info/RECORD +57 -0
  34. hafnia-0.3.0.dist-info/RECORD +0 -53
  35. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
  36. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/dataset_upload_helper.py

@@ -4,7 +4,7 @@ import base64
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import boto3
 import polars as pl
@@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, field_validator
 from cli.config import Config
 from hafnia.dataset import primitives
 from hafnia.dataset.dataset_names import (
-    ColumnName,
     DatasetVariant,
     DeploymentStage,
-    FieldName,
+    PrimitiveField,
+    SampleField,
     SplitName,
 )
 from hafnia.dataset.hafnia_dataset import Attribution, HafniaDataset, Sample, TaskInfo
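Note on the import changes above: 0.4.x replaces the old ColumnName/FieldName constants with SampleField (columns of the per-sample table, e.g. file_path, split) and PrimitiveField (fields inside each annotation struct, e.g. task_name, class_name). A minimal sketch of the distinction, with member names and values inferred only from their usages in this diff (the actual definitions in hafnia/dataset/dataset_names.py may differ):

# Hypothetical sketch; members and values are inferred from this diff,
# not copied from dataset_names.py.
class SampleField:  # sample-level columns
    FILE_PATH = "file_path"
    HEIGHT = "height"
    WIDTH = "width"
    SPLIT = "split"
    SAMPLE_INDEX = "sample_index"

class PrimitiveField:  # per-annotation struct fields
    TASK_NAME = "task_name"
    CLASS_NAME = "class_name"
    OBJECT_ID = "object_id"
    META = "meta"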
@@ -52,6 +52,7 @@ class DbDataset(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
     license_citation: Optional[str] = None
     version: Optional[str] = None
     s3_bucket_name: Optional[str] = None
+    dataset_format_version: Optional[str] = None
     annotation_date: Optional[datetime] = None
     annotation_project_id: Optional[str] = None
     annotation_dataset_id: Optional[str] = None
@@ -186,9 +187,60 @@ class EntityTypeChoices(str, Enum):  # Should match `EntityTypeChoices` in `dipd
     EVENT = "EVENT"
 
 
+class Annotations(BaseModel):
+    """
+    Used in 'DatasetImageMetadata' for visualizing image annotations
+    in gallery images on the dataset detail page.
+    """
+
+    bboxes: Optional[List[Bbox]] = None
+    classifications: Optional[List[Classification]] = None
+    polygons: Optional[List[Polygon]] = None
+    bitmasks: Optional[List[Bitmask]] = None
+
+
+class DatasetImageMetadata(BaseModel):
+    """
+    Metadata for gallery images on the dataset detail page on portal.
+    """
+
+    annotations: Optional[Annotations] = None
+    meta: Optional[Dict[str, Any]] = None
+
+    @classmethod
+    def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
+        sample = sample.model_copy(deep=True)
+        if sample.file_path is None:
+            raise ValueError("Sample has no file_path defined.")
+        sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
+        metadata = {}
+        metadata_field_names = [
+            SampleField.FILE_PATH,
+            SampleField.HEIGHT,
+            SampleField.WIDTH,
+            SampleField.SPLIT,
+        ]
+        for field_name in metadata_field_names:
+            if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
+                metadata[field_name] = getattr(sample, field_name)
+
+        obj = DatasetImageMetadata(
+            annotations=Annotations(
+                bboxes=sample.bboxes,
+                classifications=sample.classifications,
+                polygons=sample.polygons,
+                bitmasks=sample.bitmasks,
+            ),
+            meta=metadata,
+        )
+
+        return obj
+
+
 class DatasetImage(Attribution, validate_assignment=True):  # type: ignore[call-arg]
     img: str  # Base64-encoded image string
     order: Optional[int] = None
+    metadata: Optional[DatasetImageMetadata] = None
 
     @field_validator("img", mode="before")
     def validate_image_path(cls, v: Union[str, Path]) -> str:
@@ -254,7 +306,7 @@ def upload_dataset_details(cfg: Config, data: str, dataset_name: str) -> dict:
     import_endpoint = f"{dataset_endpoint}/{dataset_id}/import"
     headers = {"Authorization": cfg.api_key}
 
-    user_logger.info("Importing dataset details. This may take up to 30 seconds...")
+    user_logger.info("Exporting dataset details to platform. This may take up to 30 seconds...")
     response = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
     return response  # type: ignore[return-value]
 
@@ -293,13 +345,13 @@ def calculate_distribution_values(
     classifications = dataset_split.select(pl.col(classification_column).explode())
     classifications = classifications.filter(pl.col(classification_column).is_not_null()).unnest(classification_column)
     classifications = classifications.filter(
-        pl.col(FieldName.TASK_NAME).is_in([task.name for task in distribution_tasks])
+        pl.col(PrimitiveField.TASK_NAME).is_in([task.name for task in distribution_tasks])
     )
     dist_values = []
-    for (task_name,), task_group in classifications.group_by(FieldName.TASK_NAME):
+    for (task_name,), task_group in classifications.group_by(PrimitiveField.TASK_NAME):
         distribution_type = DbDistributionType(name=task_name)
         n_annotated_total = len(task_group)
-        for (class_name,), class_group in task_group.group_by(FieldName.CLASS_NAME):
+        for (class_name,), class_group in task_group.group_by(PrimitiveField.CLASS_NAME):
             class_count = len(class_group)
 
             dist_values.append(
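For readers unfamiliar with the polars idiom in calculate_distribution_values: classifications live in a list-of-structs column, so the distribution is computed by exploding the list, unnesting the struct into plain columns, and grouping. A self-contained sketch of the same pattern on toy data (not Hafnia API):

import polars as pl

df = pl.DataFrame(
    {
        "classifications": [
            [{"task_name": "weather", "class_name": "sunny"}],
            [{"task_name": "weather", "class_name": "rain"},
             {"task_name": "weather", "class_name": "sunny"}],
        ]
    }
)
rows = df.select(pl.col("classifications").explode())
rows = rows.filter(pl.col("classifications").is_not_null()).unnest("classifications")
for (task_name,), task_group in rows.group_by("task_name"):
    n_total = len(task_group)
    for (class_name,), class_group in task_group.group_by("class_name"):
        print(task_name, class_name, len(class_group) / n_total)  # e.g. weather sunny 0.666...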
@@ -333,6 +385,7 @@ def dataset_info_from_dataset(
     path_hidden: Optional[Path],
     path_gallery_images: Optional[Path] = None,
     gallery_image_names: Optional[List[str]] = None,
+    distribution_task_names: Optional[List[TaskInfo]] = None,
 ) -> DbDataset:
     dataset_variants = []
     dataset_reports = []
@@ -377,13 +430,15 @@ def dataset_info_from_dataset(
         )
     )
 
+    distribution_task_names = distribution_task_names or []
+    distribution_tasks = [t for t in dataset.info.tasks if t.name in distribution_task_names]
     for split_name in SplitChoices:
         split_names = SPLIT_CHOICE_MAPPING[split_name]
-        dataset_split = dataset_variant.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+        dataset_split = dataset_variant.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
 
         distribution_values = calculate_distribution_values(
             dataset_split=dataset_split,
-            distribution_tasks=dataset.info.distributions,
+            distribution_tasks=distribution_tasks,
         )
         report = DbSplitAnnotationsReport(
             variant_type=VARIANT_TYPE_MAPPING[variant_type],  # type: ignore[index]
@@ -411,7 +466,7 @@ def dataset_info_from_dataset(
 
         annotation_type = DbAnnotationType(name=AnnotationType.ObjectDetection.value)
         for (class_name, task_name), class_group in df_per_instance.group_by(
-            FieldName.CLASS_NAME, FieldName.TASK_NAME
+            PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
         ):
             if class_name is None:
                 continue
@@ -423,10 +478,10 @@ def dataset_info_from_dataset(
                     annotation_type=annotation_type,
                     task_name=task_name,
                 ),
-                unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+                unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
                 obj_instances=len(class_group),
                 annotation_type=[annotation_type],
-                images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+                images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
                 area_avg_ratio=class_group["area"].mean(),
                 area_min_ratio=class_group["area"].min(),
                 area_max_ratio=class_group["area"].max(),
@@ -445,7 +500,7 @@ def dataset_info_from_dataset(
                 width_avg_px=class_group["width_px"].mean(),
                 width_min_px=int(class_group["width_px"].min()),
                 width_max_px=int(class_group["width_px"].max()),
-                average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+                average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
             )
         )
 
@@ -459,13 +514,13 @@ def dataset_info_from_dataset(
 
         # Include only classification tasks that are defined in the dataset info
         classification_df = classification_df.filter(
-            pl.col(FieldName.TASK_NAME).is_in(classification_tasks)
+            pl.col(PrimitiveField.TASK_NAME).is_in(classification_tasks)
        )
 
        for (
            task_name,
            class_name,
-        ), class_group in classification_df.group_by(FieldName.TASK_NAME, FieldName.CLASS_NAME):
+        ), class_group in classification_df.group_by(PrimitiveField.TASK_NAME, PrimitiveField.CLASS_NAME):
            if class_name is None:
                continue
            if task_name == Classification.default_task_name():
@@ -494,7 +549,7 @@ def dataset_info_from_dataset(
         if has_primitive(dataset_split, PrimitiveType=Bitmask):
             col_name = Bitmask.column_name()
             drop_columns = [col for col in primitive_columns if col != col_name]
-            drop_columns.append(FieldName.META)
+            drop_columns.append(PrimitiveField.META)
 
             df_per_instance = table_transformations.create_primitive_table(
                 dataset_split, PrimitiveType=Bitmask, keep_sample_data=True
@@ -512,7 +567,7 @@ def dataset_info_from_dataset(
 
             annotation_type = DbAnnotationType(name=AnnotationType.InstanceSegmentation)
             for (class_name, task_name), class_group in df_per_instance.group_by(
-                FieldName.CLASS_NAME, FieldName.TASK_NAME
+                PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
             ):
                 if class_name is None:
                     continue
@@ -524,11 +579,11 @@ def dataset_info_from_dataset(
                         annotation_type=annotation_type,
                         task_name=task_name,
                     ),
-                    unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+                    unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
                     obj_instances=len(class_group),
                     annotation_type=[annotation_type],
-                    average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
-                    images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+                    average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
+                    images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
                     area_avg_ratio=class_group["area"].mean(),
                     area_min_ratio=class_group["area"].min(),
                     area_max_ratio=class_group["area"].max(),
@@ -569,7 +624,9 @@ def dataset_info_from_dataset(
         s3_bucket_name=bucket_sample,
         dataset_variants=dataset_variants,
         split_annotations_reports=dataset_reports,
-        license_citation=dataset_meta_info.get("license_citation", None),
+        latest_update=dataset.info.updated_at,
+        dataset_format_version=dataset.info.format_version,
+        license_citation=dataset.info.reference_bibtex,
         data_captured_start=dataset_meta_info.get("data_captured_start", None),
         data_captured_end=dataset_meta_info.get("data_captured_end", None),
         data_received_start=dataset_meta_info.get("data_received_start", None),
@@ -594,7 +651,7 @@ def create_gallery_images(
     path_gallery_images.mkdir(parents=True, exist_ok=True)
     COL_IMAGE_NAME = "image_name"
     samples = dataset.samples.with_columns(
-        dataset.samples[ColumnName.FILE_NAME].str.split("/").list.last().alias(COL_IMAGE_NAME)
+        dataset.samples[SampleField.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
     )
     gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))
 
@@ -604,6 +661,9 @@ def create_gallery_images(
     gallery_images = []
     for gallery_sample in gallery_samples.iter_rows(named=True):
         sample = Sample(**gallery_sample)
+
+        metadata = DatasetImageMetadata.from_sample(sample=sample)
+        sample.classifications = None  # To not draw classifications in gallery images
         image = sample.draw_annotations()
 
         path_gallery_image = path_gallery_images / gallery_sample[COL_IMAGE_NAME]
@@ -611,6 +671,7 @@ def create_gallery_images(
 
         dataset_image_dict = {
             "img": path_gallery_image,
+            "metadata": metadata,
         }
         if sample.attribution is not None:
             sample.attribution.changes = "Annotations have been visualized"
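Net effect of the dataset_upload_helper.py changes: gallery images are now uploaded together with their annotations and basic sample metadata (via the new DatasetImageMetadata), classification labels are no longer drawn into the gallery image itself, and the portal receives the dataset format version and BibTeX citation from DatasetInfo. A rough usage sketch of the new metadata class, with illustrative field values (the Sample construction mirrors the format-conversion code later in this diff):

from hafnia.dataset.hafnia_dataset import Sample

# Illustrative sample; in create_gallery_images() it comes from dataset.samples.
sample = Sample(file_path="data/images/0001.jpg", width=1920, height=1080, split="train")
metadata = DatasetImageMetadata.from_sample(sample=sample)
# metadata.meta now holds file_path (truncated to its last three path parts),
# height, width, and split; metadata.annotations carries the sample's
# bboxes/classifications/polygons/bitmasks for overlay rendering on the portal.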
hafnia/dataset/format_conversions/format_image_classification_folder.py (new file)

@@ -0,0 +1,110 @@
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional
+
+import more_itertools
+import polars as pl
+from PIL import Image
+from rich.progress import track
+
+from hafnia.dataset.dataset_names import PrimitiveField, SampleField
+from hafnia.dataset.primitives import Classification
+from hafnia.utils import is_image_file
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+def from_image_classification_folder(
+    path_folder: Path,
+    split: str,
+    n_samples: Optional[int] = None,
+) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
+    class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
+    class_names = sorted([folder.name for folder in class_folder_paths])  # Sort for determinism
+
+    # Gather all image paths per class
+    path_images_per_class: List[List[Path]] = []
+    for path_class_folder in class_folder_paths:
+        per_class_images = []
+        for path_image in list(path_class_folder.rglob("*.*")):
+            if is_image_file(path_image):
+                per_class_images.append(path_image)
+        path_images_per_class.append(sorted(per_class_images))
+
+    # Interleave to ensure classes are balanced in the output dataset for n_samples < total
+    path_images = list(more_itertools.interleave_longest(*path_images_per_class))
+
+    if n_samples is not None:
+        path_images = path_images[:n_samples]
+
+    samples = []
+    for path_image_org in track(path_images, description="Convert 'image classification' dataset to Hafnia Dataset"):
+        class_name = path_image_org.parent.name
+
+        read_image = Image.open(path_image_org)
+        width, height = read_image.size
+
+        classifications = [Classification(class_name=class_name, class_idx=class_names.index(class_name))]
+        sample = Sample(
+            file_path=str(path_image_org.absolute()),
+            width=width,
+            height=height,
+            split=split,
+            classifications=classifications,
+        )
+        samples.append(sample)
+
+    dataset_info = DatasetInfo(
+        dataset_name="ImageClassificationFromDirectoryTree",
+        tasks=[TaskInfo(primitive=Classification, class_names=class_names)],
+    )
+
+    hafnia_dataset = HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+    return hafnia_dataset
+
+
+def to_image_classification_folder(
+    dataset: "HafniaDataset",
+    path_output: Path,
+    task_name: Optional[str] = None,
+    clean_folder: bool = False,
+) -> Path:
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=Classification)
+
+    samples = dataset.samples.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
+        .alias(task.primitive.column_name())
+    )
+
+    classification_counts = samples[task.primitive.column_name()].list.len()
+    has_no_classification_samples = (classification_counts == 0).sum()
+    if has_no_classification_samples > 0:
+        raise ValueError(f"Some samples do not have a classification for task '{task.name}'.")
+
+    has_multi_classification_samples = (classification_counts > 1).sum()
+    if has_multi_classification_samples > 0:
+        raise ValueError(f"Some samples have multiple classifications for task '{task.name}'.")
+
+    if clean_folder:
+        shutil.rmtree(path_output, ignore_errors=True)
+    path_output.mkdir(parents=True, exist_ok=True)
+
+    description = "Export Hafnia Dataset to directory tree"
+    for sample_dict in track(samples.iter_rows(named=True), total=len(samples), description=description):
+        classifications = sample_dict[task.primitive.column_name()]
+        if len(classifications) != 1:
+            raise ValueError("Each sample should have exactly one classification.")
+        classification = classifications[0]
+        class_name = classification[PrimitiveField.CLASS_NAME].replace("/", "_")  # Avoid issues with subfolders
+        path_class_folder = path_output / class_name
+        path_class_folder.mkdir(parents=True, exist_ok=True)
+
+        path_image_org = Path(sample_dict[SampleField.FILE_PATH])
+        path_image_new = path_class_folder / path_image_org.name
+        shutil.copy2(path_image_org, path_image_new)
+
+    return path_output
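The new format_image_classification_folder module gives a round trip between the common directory-tree layout (one sub-folder per class) and HafniaDataset. A usage sketch with illustrative paths (assumes e.g. flowers/train/daisy/*.jpg and flowers/train/tulip/*.jpg exist):

from pathlib import Path

from hafnia.dataset.format_conversions.format_image_classification_folder import (
    from_image_classification_folder,
    to_image_classification_folder,
)

# Folder tree -> HafniaDataset; images are interleaved across classes,
# so n_samples=100 keeps the class balance roughly even.
dataset = from_image_classification_folder(path_folder=Path("flowers/train"), split="train", n_samples=100)

# ... and back to a class-per-folder tree (raises if any sample has zero
# or multiple classifications for the chosen task).
to_image_classification_folder(dataset=dataset, path_output=Path("export/flowers"), clean_folder=True)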
hafnia/dataset/format_conversions/format_yolo.py (new file)

@@ -0,0 +1,164 @@
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional
+
+from PIL import Image
+from rich.progress import track
+
+from hafnia.dataset import primitives
+from hafnia.dataset.dataset_names import SplitName
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+FILENAME_YOLO_CLASS_NAMES = "obj.names"
+FILENAME_YOLO_IMAGES_TXT = "images.txt"
+
+
+def get_image_size(path: Path) -> tuple[int, int]:
+    with Image.open(path) as img:
+        return img.size  # (width, height)
+
+
+def from_yolo_format(
+    path_yolo_dataset: Path,
+    split_name: str = SplitName.UNDEFINED,
+    dataset_name: str = "yolo-dataset",
+    filename_class_names: str = FILENAME_YOLO_CLASS_NAMES,
+    filename_images_txt: str = FILENAME_YOLO_IMAGES_TXT,
+) -> "HafniaDataset":
+    """
+    Imports a YOLO (Darknet) formatted dataset as a HafniaDataset.
+    """
+    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
+    path_class_names = path_yolo_dataset / filename_class_names
+
+    if split_name not in SplitName.all_split_names():
+        raise ValueError(f"Invalid split name: {split_name}. Must be one of {SplitName.all_split_names()}")
+
+    if not path_class_names.exists():
+        raise FileNotFoundError(f"File with class names not found at '{path_class_names.resolve()}'.")
+
+    class_names_text = path_class_names.read_text()
+    if class_names_text.strip() == "":
+        raise ValueError(f"File with class names at '{path_class_names.resolve()}' is empty")
+
+    class_names = [class_name for class_name in class_names_text.splitlines() if class_name.strip() != ""]
+
+    if len(class_names) == 0:
+        raise ValueError(f"File with class names at '{path_class_names.resolve()}' has no class names")
+
+    path_images_txt = path_yolo_dataset / filename_images_txt
+
+    if not path_images_txt.exists():
+        raise FileNotFoundError(f"File with images not found at '{path_images_txt.resolve()}'")
+
+    images_txt_text = path_images_txt.read_text()
+    if len(images_txt_text.strip()) == 0:
+        raise ValueError(f"File is empty at '{path_images_txt.resolve()}'")
+
+    image_paths_raw = [line.strip() for line in images_txt_text.splitlines()]
+
+    samples: List[Sample] = []
+    for image_path_raw in track(image_paths_raw):
+        path_image = path_yolo_dataset / image_path_raw
+        if not path_image.exists():
+            raise FileNotFoundError(f"File with image not found at '{path_image.resolve()}'")
+        width, height = get_image_size(path_image)
+
+        path_label = path_image.with_suffix(".txt")
+        if not path_label.exists():
+            raise FileNotFoundError(f"File with labels not found at '{path_label.resolve()}'")
+
+        boxes: List[primitives.Bbox] = []
+        bbox_strings = path_label.read_text().splitlines()
+        for bbox_string in bbox_strings:
+            parts = bbox_string.strip().split()
+            if len(parts) != 5:
+                raise ValueError(f"Invalid bbox format in file {path_label.resolve()}: {bbox_string}")
+
+            class_idx = int(parts[0])
+            x_center, y_center, bbox_width, bbox_height = (float(value) for value in parts[1:5])
+
+            top_left_x = x_center - bbox_width / 2
+            top_left_y = y_center - bbox_height / 2
+
+            bbox = primitives.Bbox(
+                top_left_x=top_left_x,
+                top_left_y=top_left_y,
+                width=bbox_width,
+                height=bbox_height,
+                class_idx=class_idx,
+                class_name=class_names[class_idx] if 0 <= class_idx < len(class_names) else None,
+            )
+            boxes.append(bbox)
+
+        sample = Sample(
+            file_path=path_image.absolute().as_posix(),
+            height=height,
+            width=width,
+            split=split_name,
+            bboxes=boxes,
+        )
+        samples.append(sample)
+
+    tasks = [TaskInfo(primitive=primitives.Bbox, class_names=class_names)]
+    info = DatasetInfo(dataset_name=dataset_name, tasks=tasks)
+    hafnia_dataset = HafniaDataset.from_samples_list(samples, info=info)
+    return hafnia_dataset
+
+
+def to_yolo_format(
+    dataset: "HafniaDataset",
+    path_export_yolo_dataset: Path,
+    task_name: Optional[str] = None,
+):
+    """Exports a HafniaDataset as YOLO (Darknet) format."""
+    from hafnia.dataset.hafnia_dataset import Sample
+
+    bbox_task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitives.Bbox)
+
+    class_names = bbox_task.class_names or []
+    if len(class_names) == 0:
+        raise ValueError(
+            f"Hafnia dataset task '{bbox_task.name}' has no class names defined. This is required for YOLO export."
+        )
+    path_export_yolo_dataset.mkdir(parents=True, exist_ok=True)
+    path_class_names = path_export_yolo_dataset / FILENAME_YOLO_CLASS_NAMES
+    path_class_names.write_text("\n".join(class_names))
+
+    path_data_folder = path_export_yolo_dataset / "data"
+    path_data_folder.mkdir(parents=True, exist_ok=True)
+    image_paths: List[str] = []
+    for sample_dict in dataset:
+        sample = Sample(**sample_dict)
+        if sample.file_path is None:
+            raise ValueError("Sample has no file_path defined.")
+        path_image_src = Path(sample.file_path)
+        path_image_dst = path_data_folder / path_image_src.name
+        shutil.copy2(path_image_src, path_image_dst)
+        image_paths.append(path_image_dst.relative_to(path_export_yolo_dataset).as_posix())
+        path_label = path_image_dst.with_suffix(".txt")
+        bboxes = sample.bboxes or []
+        bbox_strings = [bbox_to_yolo_format(bbox) for bbox in bboxes]
+        path_label.write_text("\n".join(bbox_strings))
+
+    path_images_txt = path_export_yolo_dataset / FILENAME_YOLO_IMAGES_TXT
+    path_images_txt.write_text("\n".join(image_paths))
+
+
+def bbox_to_yolo_format(bbox: primitives.Bbox) -> str:
+    """
+    From hafnia bbox to yolo bbox string conversion.
+    Both yolo and hafnia use normalized coordinates [0, 1].
+    Hafnia: top_left_x, top_left_y, width, height
+    Yolo (darknet): "<object-class> <x_center> <y_center> <width> <height>"
+    Example (3 bounding boxes):
+    1 0.716797 0.395833 0.216406 0.147222
+    0 0.687109 0.379167 0.255469 0.158333
+    1 0.420312 0.395833 0.140625 0.166667
+    """
+    x_center = bbox.top_left_x + bbox.width / 2
+    y_center = bbox.top_left_y + bbox.height / 2
+    return f"{bbox.class_idx} {x_center} {y_center} {bbox.width} {bbox.height}"