hafnia 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hafnia/__init__.py CHANGED
@@ -3,4 +3,4 @@ from importlib.metadata import version
3
3
  __package_name__ = "hafnia"
4
4
  __version__ = version(__package_name__)
5
5
 
6
- __dataset_format_version__ = "0.1.0" # Hafnia dataset format version
6
+ __dataset_format_version__ = "0.2.0" # Hafnia dataset format version
@@ -1,5 +1,8 @@
1
1
  from enum import Enum
2
- from typing import List
2
+ from typing import Dict, List, Optional
3
+
4
+ import boto3
5
+ from pydantic import BaseModel, field_validator
3
6
 
4
7
  FILENAME_RECIPE_JSON = "recipe.json"
5
8
  FILENAME_DATASET_INFO = "dataset_info.json"
@@ -23,7 +26,7 @@ TAG_IS_SAMPLE = "sample"
23
26
  OPS_REMOVE_CLASS = "__REMOVE__"
24
27
 
25
28
 
26
- class FieldName:
29
+ class PrimitiveField:
27
30
  CLASS_NAME: str = "class_name" # Name of the class this primitive is associated with, e.g. "car" for Bbox
28
31
  CLASS_IDX: str = "class_idx" # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class # noqa: E501
29
32
  OBJECT_ID: str = "object_id" # Unique identifier for the object, e.g. "12345123"
@@ -38,40 +41,150 @@ class FieldName:
38
41
  Returns a list of expected field names for primitives.
39
42
  """
40
43
  return [
41
- FieldName.CLASS_NAME,
42
- FieldName.CLASS_IDX,
43
- FieldName.OBJECT_ID,
44
- FieldName.CONFIDENCE,
45
- FieldName.META,
46
- FieldName.TASK_NAME,
44
+ PrimitiveField.CLASS_NAME,
45
+ PrimitiveField.CLASS_IDX,
46
+ PrimitiveField.OBJECT_ID,
47
+ PrimitiveField.CONFIDENCE,
48
+ PrimitiveField.META,
49
+ PrimitiveField.TASK_NAME,
47
50
  ]
48
51
 
49
52
 
50
- class ColumnName:
51
- SAMPLE_INDEX: str = "sample_index"
53
+ class SampleField:
52
54
  FILE_PATH: str = "file_path"
53
55
  HEIGHT: str = "height"
54
56
  WIDTH: str = "width"
55
57
  SPLIT: str = "split"
58
+ TAGS: str = "tags"
59
+
60
+ CLASSIFICATIONS: str = "classifications"
61
+ BBOXES: str = "bboxes"
62
+ BITMASKS: str = "bitmasks"
63
+ POLYGONS: str = "polygons"
64
+
65
+ STORAGE_FORMAT: str = "storage_format" # E.g. "image", "video", "zip"
66
+ COLLECTION_INDEX: str = "collection_index"
67
+ COLLECTION_ID: str = "collection_id"
56
68
  REMOTE_PATH: str = "remote_path" # Path to the file in remote storage, e.g. S3
69
+ SAMPLE_INDEX: str = "sample_index"
70
+
57
71
  ATTRIBUTION: str = "attribution" # Attribution for the sample (image/video), e.g. creator, license, source, etc.
58
- TAGS: str = "tags"
59
72
  META: str = "meta"
60
73
  DATASET_NAME: str = "dataset_name"
61
74
 
62
75
 
76
+ class StorageFormat:
77
+ IMAGE: str = "image"
78
+ VIDEO: str = "video"
79
+ ZIP: str = "zip"
80
+
81
+
63
82
  class SplitName:
64
- TRAIN = "train"
65
- VAL = "validation"
66
- TEST = "test"
67
- UNDEFINED = "UNDEFINED"
83
+ TRAIN: str = "train"
84
+ VAL: str = "validation"
85
+ TEST: str = "test"
86
+ UNDEFINED: str = "UNDEFINED"
68
87
 
69
88
  @staticmethod
70
89
  def valid_splits() -> List[str]:
71
90
  return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
72
91
 
92
+ @staticmethod
93
+ def all_split_names() -> List[str]:
94
+ return [*SplitName.valid_splits(), SplitName.UNDEFINED]
95
+
73
96
 
74
97
  class DatasetVariant(Enum):
75
98
  DUMP = "dump"
76
99
  SAMPLE = "sample"
77
100
  HIDDEN = "hidden"
101
+
102
+
103
+ class AwsCredentials(BaseModel):
104
+ access_key: str
105
+ secret_key: str
106
+ session_token: str
107
+ region: Optional[str]
108
+
109
+ def aws_credentials(self) -> Dict[str, str]:
110
+ """
111
+ Returns the AWS credentials as a dictionary.
112
+ """
113
+ environment_vars = {
114
+ "AWS_ACCESS_KEY_ID": self.access_key,
115
+ "AWS_SECRET_ACCESS_KEY": self.secret_key,
116
+ "AWS_SESSION_TOKEN": self.session_token,
117
+ }
118
+ if self.region:
119
+ environment_vars["AWS_REGION"] = self.region
120
+
121
+ return environment_vars
122
+
123
+ @staticmethod
124
+ def from_session(session: boto3.Session) -> "AwsCredentials":
125
+ """
126
+ Creates AwsCredentials from a Boto3 session.
127
+ """
128
+ frozen_credentials = session.get_credentials().get_frozen_credentials()
129
+ return AwsCredentials(
130
+ access_key=frozen_credentials.access_key,
131
+ secret_key=frozen_credentials.secret_key,
132
+ session_token=frozen_credentials.token,
133
+ region=session.region_name,
134
+ )
135
+
136
+
137
+ ARN_PREFIX = "arn:aws:s3:::"
138
+
139
+
140
+ class ResourceCredentials(AwsCredentials):
141
+ s3_arn: str
142
+
143
+ @staticmethod
144
+ def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
145
+ """
146
+ The endpoint returns a payload with a key called 's3_path', but it
147
is actually an ARN path (starts with arn:aws:s3:::). This method renames it to 's3_arn' for consistency.
148
+ """
149
+ if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
150
+ payload["s3_arn"] = payload.pop("s3_path")
151
+
152
+ if "region" not in payload:
153
+ payload["region"] = "eu-west-1"
154
+ return ResourceCredentials(**payload)
155
+
156
+ @field_validator("s3_arn")
157
+ @classmethod
158
+ def validate_s3_arn(cls, value: str) -> str:
159
+ """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
160
+ if not value.startswith("arn:aws:s3:::"):
161
+ raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
162
+ return value
163
+
164
+ def s3_path(self) -> str:
165
+ """
166
+ Extracts the S3 path from the ARN.
167
+ Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
168
+ """
169
+ return self.s3_arn[len(ARN_PREFIX) :]
170
+
171
+ def s3_uri(self) -> str:
172
+ """
173
+ Converts the S3 ARN to a URI format.
174
+ Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
175
+ """
176
+ return f"s3://{self.s3_path()}"
177
+
178
+ def bucket_name(self) -> str:
179
+ """
180
+ Extracts the bucket name from the S3 ARN.
181
+ Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
182
+ """
183
+ return self.s3_path().split("/")[0]
184
+
185
+ def object_key(self) -> str:
186
+ """
187
+ Extracts the object key from the S3 ARN.
188
+ Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
189
+ """
190
+ return "/".join(self.s3_path().split("/")[1:])
@@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, field_validator
14
14
  from cli.config import Config
15
15
  from hafnia.dataset import primitives
16
16
  from hafnia.dataset.dataset_names import (
17
- ColumnName,
18
17
  DatasetVariant,
19
18
  DeploymentStage,
20
- FieldName,
19
+ PrimitiveField,
20
+ SampleField,
21
21
  SplitName,
22
22
  )
23
23
  from hafnia.dataset.hafnia_dataset import Attribution, HafniaDataset, Sample, TaskInfo
@@ -193,7 +193,7 @@ class Annotations(BaseModel):
193
193
  in gallery images on the dataset detail page.
194
194
  """
195
195
 
196
- objects: Optional[List[Bbox]] = None
196
+ bboxes: Optional[List[Bbox]] = None
197
197
  classifications: Optional[List[Classification]] = None
198
198
  polygons: Optional[List[Polygon]] = None
199
199
  bitmasks: Optional[List[Bitmask]] = None
@@ -210,13 +210,15 @@ class DatasetImageMetadata(BaseModel):
210
210
  @classmethod
211
211
  def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
212
212
  sample = sample.model_copy(deep=True)
213
+ if sample.file_path is None:
214
+ raise ValueError("Sample has no file_path defined.")
213
215
  sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
214
216
  metadata = {}
215
217
  metadata_field_names = [
216
- ColumnName.FILE_PATH,
217
- ColumnName.HEIGHT,
218
- ColumnName.WIDTH,
219
- ColumnName.SPLIT,
218
+ SampleField.FILE_PATH,
219
+ SampleField.HEIGHT,
220
+ SampleField.WIDTH,
221
+ SampleField.SPLIT,
220
222
  ]
221
223
  for field_name in metadata_field_names:
222
224
  if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
@@ -224,7 +226,7 @@ class DatasetImageMetadata(BaseModel):
224
226
 
225
227
  obj = DatasetImageMetadata(
226
228
  annotations=Annotations(
227
- objects=sample.objects,
229
+ bboxes=sample.bboxes,
228
230
  classifications=sample.classifications,
229
231
  polygons=sample.polygons,
230
232
  bitmasks=sample.bitmasks,
@@ -343,13 +345,13 @@ def calculate_distribution_values(
343
345
  classifications = dataset_split.select(pl.col(classification_column).explode())
344
346
  classifications = classifications.filter(pl.col(classification_column).is_not_null()).unnest(classification_column)
345
347
  classifications = classifications.filter(
346
- pl.col(FieldName.TASK_NAME).is_in([task.name for task in distribution_tasks])
348
+ pl.col(PrimitiveField.TASK_NAME).is_in([task.name for task in distribution_tasks])
347
349
  )
348
350
  dist_values = []
349
- for (task_name,), task_group in classifications.group_by(FieldName.TASK_NAME):
351
+ for (task_name,), task_group in classifications.group_by(PrimitiveField.TASK_NAME):
350
352
  distribution_type = DbDistributionType(name=task_name)
351
353
  n_annotated_total = len(task_group)
352
- for (class_name,), class_group in task_group.group_by(FieldName.CLASS_NAME):
354
+ for (class_name,), class_group in task_group.group_by(PrimitiveField.CLASS_NAME):
353
355
  class_count = len(class_group)
354
356
 
355
357
  dist_values.append(
@@ -383,6 +385,7 @@ def dataset_info_from_dataset(
383
385
  path_hidden: Optional[Path],
384
386
  path_gallery_images: Optional[Path] = None,
385
387
  gallery_image_names: Optional[List[str]] = None,
388
+ distribution_task_names: Optional[List[TaskInfo]] = None,
386
389
  ) -> DbDataset:
387
390
  dataset_variants = []
388
391
  dataset_reports = []
@@ -427,13 +430,15 @@ def dataset_info_from_dataset(
427
430
  )
428
431
  )
429
432
 
433
+ distribution_task_names = distribution_task_names or []
434
+ distribution_tasks = [t for t in dataset.info.tasks if t.name in distribution_task_names]
430
435
  for split_name in SplitChoices:
431
436
  split_names = SPLIT_CHOICE_MAPPING[split_name]
432
- dataset_split = dataset_variant.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
437
+ dataset_split = dataset_variant.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
433
438
 
434
439
  distribution_values = calculate_distribution_values(
435
440
  dataset_split=dataset_split,
436
- distribution_tasks=dataset.info.distributions,
441
+ distribution_tasks=distribution_tasks,
437
442
  )
438
443
  report = DbSplitAnnotationsReport(
439
444
  variant_type=VARIANT_TYPE_MAPPING[variant_type], # type: ignore[index]
@@ -461,7 +466,7 @@ def dataset_info_from_dataset(
461
466
 
462
467
  annotation_type = DbAnnotationType(name=AnnotationType.ObjectDetection.value)
463
468
  for (class_name, task_name), class_group in df_per_instance.group_by(
464
- FieldName.CLASS_NAME, FieldName.TASK_NAME
469
+ PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
465
470
  ):
466
471
  if class_name is None:
467
472
  continue
@@ -473,10 +478,10 @@ def dataset_info_from_dataset(
473
478
  annotation_type=annotation_type,
474
479
  task_name=task_name,
475
480
  ),
476
- unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
481
+ unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
477
482
  obj_instances=len(class_group),
478
483
  annotation_type=[annotation_type],
479
- images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
484
+ images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
480
485
  area_avg_ratio=class_group["area"].mean(),
481
486
  area_min_ratio=class_group["area"].min(),
482
487
  area_max_ratio=class_group["area"].max(),
@@ -495,7 +500,7 @@ def dataset_info_from_dataset(
495
500
  width_avg_px=class_group["width_px"].mean(),
496
501
  width_min_px=int(class_group["width_px"].min()),
497
502
  width_max_px=int(class_group["width_px"].max()),
498
- average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
503
+ average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
499
504
  )
500
505
  )
501
506
 
@@ -509,13 +514,13 @@ def dataset_info_from_dataset(
509
514
 
510
515
  # Include only classification tasks that are defined in the dataset info
511
516
  classification_df = classification_df.filter(
512
- pl.col(FieldName.TASK_NAME).is_in(classification_tasks)
517
+ pl.col(PrimitiveField.TASK_NAME).is_in(classification_tasks)
513
518
  )
514
519
 
515
520
  for (
516
521
  task_name,
517
522
  class_name,
518
- ), class_group in classification_df.group_by(FieldName.TASK_NAME, FieldName.CLASS_NAME):
523
+ ), class_group in classification_df.group_by(PrimitiveField.TASK_NAME, PrimitiveField.CLASS_NAME):
519
524
  if class_name is None:
520
525
  continue
521
526
  if task_name == Classification.default_task_name():
@@ -544,7 +549,7 @@ def dataset_info_from_dataset(
544
549
  if has_primitive(dataset_split, PrimitiveType=Bitmask):
545
550
  col_name = Bitmask.column_name()
546
551
  drop_columns = [col for col in primitive_columns if col != col_name]
547
- drop_columns.append(FieldName.META)
552
+ drop_columns.append(PrimitiveField.META)
548
553
 
549
554
  df_per_instance = table_transformations.create_primitive_table(
550
555
  dataset_split, PrimitiveType=Bitmask, keep_sample_data=True
@@ -562,7 +567,7 @@ def dataset_info_from_dataset(
562
567
 
563
568
  annotation_type = DbAnnotationType(name=AnnotationType.InstanceSegmentation)
564
569
  for (class_name, task_name), class_group in df_per_instance.group_by(
565
- FieldName.CLASS_NAME, FieldName.TASK_NAME
570
+ PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
566
571
  ):
567
572
  if class_name is None:
568
573
  continue
@@ -574,11 +579,11 @@ def dataset_info_from_dataset(
574
579
  annotation_type=annotation_type,
575
580
  task_name=task_name,
576
581
  ),
577
- unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
582
+ unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
578
583
  obj_instances=len(class_group),
579
584
  annotation_type=[annotation_type],
580
- average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
581
- images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
585
+ average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
586
+ images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
582
587
  area_avg_ratio=class_group["area"].mean(),
583
588
  area_min_ratio=class_group["area"].min(),
584
589
  area_max_ratio=class_group["area"].max(),
@@ -646,7 +651,7 @@ def create_gallery_images(
646
651
  path_gallery_images.mkdir(parents=True, exist_ok=True)
647
652
  COL_IMAGE_NAME = "image_name"
648
653
  samples = dataset.samples.with_columns(
649
- dataset.samples[ColumnName.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
654
+ dataset.samples[SampleField.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
650
655
  )
651
656
  gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))
652
657
 
@@ -1,23 +1,27 @@
1
1
  import shutil
2
2
  from pathlib import Path
3
- from typing import List, Optional
3
+ from typing import TYPE_CHECKING, List, Optional
4
4
 
5
5
  import more_itertools
6
6
  import polars as pl
7
7
  from PIL import Image
8
8
  from rich.progress import track
9
9
 
10
- from hafnia.dataset.dataset_names import ColumnName, FieldName
11
- from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
10
+ from hafnia.dataset.dataset_names import PrimitiveField, SampleField
12
11
  from hafnia.dataset.primitives import Classification
13
12
  from hafnia.utils import is_image_file
14
13
 
14
+ if TYPE_CHECKING:
15
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
15
16
 
16
- def import_image_classification_directory_tree(
17
+
18
+ def from_image_classification_folder(
17
19
  path_folder: Path,
18
20
  split: str,
19
21
  n_samples: Optional[int] = None,
20
- ) -> HafniaDataset:
22
+ ) -> "HafniaDataset":
23
+ from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
24
+
21
25
  class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
22
26
  class_names = sorted([folder.name for folder in class_folder_paths]) # Sort for determinism
23
27
 
@@ -62,8 +66,8 @@ def import_image_classification_directory_tree(
62
66
  return hafnia_dataset
63
67
 
64
68
 
65
- def export_image_classification_directory_tree(
66
- dataset: HafniaDataset,
69
+ def to_image_classification_folder(
70
+ dataset: "HafniaDataset",
67
71
  path_output: Path,
68
72
  task_name: Optional[str] = None,
69
73
  clean_folder: bool = False,
@@ -72,7 +76,7 @@ def export_image_classification_directory_tree(
72
76
 
73
77
  samples = dataset.samples.with_columns(
74
78
  pl.col(task.primitive.column_name())
75
- .list.filter(pl.element().struct.field(FieldName.TASK_NAME) == task.name)
79
+ .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
76
80
  .alias(task.primitive.column_name())
77
81
  )
78
82
 
@@ -95,11 +99,11 @@ def export_image_classification_directory_tree(
95
99
  if len(classifications) != 1:
96
100
  raise ValueError("Each sample should have exactly one classification.")
97
101
  classification = classifications[0]
98
- class_name = classification[FieldName.CLASS_NAME].replace("/", "_") # Avoid issues with subfolders
102
+ class_name = classification[PrimitiveField.CLASS_NAME].replace("/", "_") # Avoid issues with subfolders
99
103
  path_class_folder = path_output / class_name
100
104
  path_class_folder.mkdir(parents=True, exist_ok=True)
101
105
 
102
- path_image_org = Path(sample_dict[ColumnName.FILE_PATH])
106
+ path_image_org = Path(sample_dict[SampleField.FILE_PATH])
103
107
  path_image_new = path_class_folder / path_image_org.name
104
108
  shutil.copy2(path_image_org, path_image_new)
105
109
 
@@ -0,0 +1,164 @@
1
+ import shutil
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, List, Optional
4
+
5
+ from PIL import Image
6
+ from rich.progress import track
7
+
8
+ from hafnia.dataset import primitives
9
+ from hafnia.dataset.dataset_names import SplitName
10
+
11
+ if TYPE_CHECKING:
12
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
13
+
14
+ FILENAME_YOLO_CLASS_NAMES = "obj.names"
15
+ FILENAME_YOLO_IMAGES_TXT = "images.txt"
16
+
17
+
18
+ def get_image_size(path: Path) -> tuple[int, int]:
19
+ with Image.open(path) as img:
20
+ return img.size # (width, height)
21
+
22
+
23
+ def from_yolo_format(
24
+ path_yolo_dataset: Path,
25
+ split_name: str = SplitName.UNDEFINED,
26
+ dataset_name: str = "yolo-dataset",
27
+ filename_class_names: str = FILENAME_YOLO_CLASS_NAMES,
28
+ filename_images_txt: str = FILENAME_YOLO_IMAGES_TXT,
29
+ ) -> "HafniaDataset":
30
+ """
31
+ Imports a YOLO (Darknet) formatted dataset as a HafniaDataset.
32
+ """
33
+ from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
34
+
35
+ path_class_names = path_yolo_dataset / filename_class_names
36
+
37
+ if split_name not in SplitName.all_split_names():
38
+ raise ValueError(f"Invalid split name: {split_name}. Must be one of {SplitName.all_split_names()}")
39
+
40
+ if not path_class_names.exists():
41
+ raise FileNotFoundError(f"File with class names not found at '{path_class_names.resolve()}'.")
42
+
43
+ class_names_text = path_class_names.read_text()
44
+ if class_names_text.strip() == "":
45
+ raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' is empty")
46
+
47
+ class_names = [class_name for class_name in class_names_text.splitlines() if class_name.strip() != ""]
48
+
49
+ if len(class_names) == 0:
50
+ raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' has no class names")
51
+
52
+ path_images_txt = path_yolo_dataset / filename_images_txt
53
+
54
+ if not path_images_txt.exists():
55
+ raise FileNotFoundError(f"File with images not found at '{path_images_txt.resolve()}'")
56
+
57
+ images_txt_text = path_images_txt.read_text()
58
+ if len(images_txt_text.strip()) == 0:
59
+ raise ValueError(f"File is empty at '{path_images_txt.resolve()}'")
60
+
61
+ image_paths_raw = [line.strip() for line in images_txt_text.splitlines()]
62
+
63
+ samples: List[Sample] = []
64
+ for image_path_raw in track(image_paths_raw):
65
+ path_image = path_yolo_dataset / image_path_raw
66
+ if not path_image.exists():
67
+ raise FileNotFoundError(f"File with image not found at '{path_image.resolve()}'")
68
+ width, height = get_image_size(path_image)
69
+
70
+ path_label = path_image.with_suffix(".txt")
71
+ if not path_label.exists():
72
+ raise FileNotFoundError(f"File with labels not found at '{path_label.resolve()}'")
73
+
74
+ boxes: List[primitives.Bbox] = []
75
+ bbox_strings = path_label.read_text().splitlines()
76
+ for bbox_string in bbox_strings:
77
+ parts = bbox_string.strip().split()
78
+ if len(parts) != 5:
79
+ raise ValueError(f"Invalid bbox format in file {path_label.resolve()}: {bbox_string}")
80
+
81
+ class_idx = int(parts[0])
82
+ x_center, y_center, bbox_width, bbox_height = (float(value) for value in parts[1:5])
83
+
84
+ top_left_x = x_center - bbox_width / 2
85
+ top_left_y = y_center - bbox_height / 2
86
+
87
+ bbox = primitives.Bbox(
88
+ top_left_x=top_left_x,
89
+ top_left_y=top_left_y,
90
+ width=bbox_width,
91
+ height=bbox_height,
92
+ class_idx=class_idx,
93
+ class_name=class_names[class_idx] if 0 <= class_idx < len(class_names) else None,
94
+ )
95
+ boxes.append(bbox)
96
+
97
+ sample = Sample(
98
+ file_path=path_image.absolute().as_posix(),
99
+ height=height,
100
+ width=width,
101
+ split=split_name,
102
+ bboxes=boxes,
103
+ )
104
+ samples.append(sample)
105
+
106
+ tasks = [TaskInfo(primitive=primitives.Bbox, class_names=class_names)]
107
+ info = DatasetInfo(dataset_name=dataset_name, tasks=tasks)
108
+ hafnia_dataset = HafniaDataset.from_samples_list(samples, info=info)
109
+ return hafnia_dataset
110
+
111
+
112
+ def to_yolo_format(
113
+ dataset: "HafniaDataset",
114
+ path_export_yolo_dataset: Path,
115
+ task_name: Optional[str] = None,
116
+ ):
117
+ """Exports a HafniaDataset as YOLO (Darknet) format."""
118
+ from hafnia.dataset.hafnia_dataset import Sample
119
+
120
+ bbox_task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitives.Bbox)
121
+
122
+ class_names = bbox_task.class_names or []
123
+ if len(class_names) == 0:
124
+ raise ValueError(
125
+ f"Hafnia dataset task '{bbox_task.name}' has no class names defined. This is required for YOLO export."
126
+ )
127
+ path_export_yolo_dataset.mkdir(parents=True, exist_ok=True)
128
+ path_class_names = path_export_yolo_dataset / FILENAME_YOLO_CLASS_NAMES
129
+ path_class_names.write_text("\n".join(class_names))
130
+
131
+ path_data_folder = path_export_yolo_dataset / "data"
132
+ path_data_folder.mkdir(parents=True, exist_ok=True)
133
+ image_paths: List[str] = []
134
+ for sample_dict in dataset:
135
+ sample = Sample(**sample_dict)
136
+ if sample.file_path is None:
137
+ raise ValueError("Sample has no file_path defined.")
138
+ path_image_src = Path(sample.file_path)
139
+ path_image_dst = path_data_folder / path_image_src.name
140
+ shutil.copy2(path_image_src, path_image_dst)
141
+ image_paths.append(path_image_dst.relative_to(path_export_yolo_dataset).as_posix())
142
+ path_label = path_image_dst.with_suffix(".txt")
143
+ bboxes = sample.bboxes or []
144
+ bbox_strings = [bbox_to_yolo_format(bbox) for bbox in bboxes]
145
+ path_label.write_text("\n".join(bbox_strings))
146
+
147
+ path_images_txt = path_export_yolo_dataset / FILENAME_YOLO_IMAGES_TXT
148
+ path_images_txt.write_text("\n".join(image_paths))
149
+
150
+
151
+ def bbox_to_yolo_format(bbox: primitives.Bbox) -> str:
152
+ """
153
+ From hafnia bbox to yolo bbox string conversion
154
+ Both yolo and hafnia use normalized coordinates [0, 1]
155
+ Hafnia: top_left_x, top_left_y, width, height
156
+ Yolo (darknet): "<object-class> <x_center> <y_center> <width> <height>"
157
+ Example (3 bounding boxes):
158
+ 1 0.716797 0.395833 0.216406 0.147222
159
+ 0 0.687109 0.379167 0.255469 0.158333
160
+ 1 0.420312 0.395833 0.140625 0.166667
161
+ """
162
+ x_center = bbox.top_left_x + bbox.width / 2
163
+ y_center = bbox.top_left_y + bbox.height / 2
164
+ return f"{bbox.class_idx} {x_center} {y_center} {bbox.width} {bbox.height}"
@@ -14,8 +14,8 @@ from torchvision.datasets.utils import download_and_extract_archive, extract_arc
14
14
  from hafnia import utils
15
15
  from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name
16
16
  from hafnia.dataset.dataset_names import SplitName
17
- from hafnia.dataset.format_conversions.image_classification_from_directory import (
18
- import_image_classification_directory_tree,
17
+ from hafnia.dataset.format_conversions.format_image_classification_folder import (
18
+ from_image_classification_folder,
19
19
  )
20
20
  from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
21
21
  from hafnia.dataset.primitives import Classification
@@ -72,7 +72,7 @@ def caltech_101_as_hafnia_dataset(
72
72
  path_image_classification_folder = _download_and_extract_caltech_dataset(
73
73
  dataset_name, force_redownload=force_redownload
74
74
  )
75
- hafnia_dataset = import_image_classification_directory_tree(
75
+ hafnia_dataset = from_image_classification_folder(
76
76
  path_image_classification_folder,
77
77
  split=SplitName.TRAIN,
78
78
  n_samples=n_samples,
@@ -102,7 +102,7 @@ def caltech_256_as_hafnia_dataset(
102
102
  path_image_classification_folder = _download_and_extract_caltech_dataset(
103
103
  dataset_name, force_redownload=force_redownload
104
104
  )
105
- hafnia_dataset = import_image_classification_directory_tree(
105
+ hafnia_dataset = from_image_classification_folder(
106
106
  path_image_classification_folder,
107
107
  split=SplitName.TRAIN,
108
108
  n_samples=n_samples,
@@ -122,6 +122,12 @@ def caltech_256_as_hafnia_dataset(
122
122
  }""")
123
123
  hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/nyy15-4j048"
124
124
 
125
+ task = hafnia_dataset.info.get_task_by_primitive(Classification)
126
+
127
+ # Class Mapping: To remove numeric prefixes from class names
128
+ # E.g. "001.ak47 --> ak47", "002.american-flag --> american-flag", ...
129
+ class_mapping = {name: name.split(".")[-1] for name in task.class_names or []}
130
+ hafnia_dataset = hafnia_dataset.class_mapper(class_mapping=class_mapping, task_name=task.name)
125
131
  return hafnia_dataset
126
132
 
127
133