hafnia 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Files changed (40)
  1. hafnia/__init__.py +1 -1
  2. hafnia/dataset/dataset_names.py +128 -15
  3. hafnia/dataset/dataset_recipe/dataset_recipe.py +3 -3
  4. hafnia/dataset/dataset_upload_helper.py +31 -26
  5. hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
  6. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  7. hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
  8. hafnia/dataset/hafnia_dataset.py +246 -72
  9. hafnia/dataset/operations/dataset_stats.py +82 -70
  10. hafnia/dataset/operations/dataset_transformations.py +102 -37
  11. hafnia/dataset/operations/table_transformations.py +132 -15
  12. hafnia/dataset/primitives/bbox.py +3 -5
  13. hafnia/dataset/primitives/bitmask.py +2 -7
  14. hafnia/dataset/primitives/classification.py +3 -3
  15. hafnia/dataset/primitives/polygon.py +2 -4
  16. hafnia/dataset/primitives/primitive.py +1 -1
  17. hafnia/dataset/primitives/segmentation.py +2 -2
  18. hafnia/platform/datasets.py +4 -8
  19. hafnia/platform/download.py +1 -72
  20. hafnia/torch_helpers.py +12 -12
  21. hafnia/utils.py +1 -1
  22. hafnia/visualizations/image_visualizations.py +2 -0
  23. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/METADATA +4 -4
  24. hafnia-0.4.2.dist-info/RECORD +57 -0
  25. hafnia-0.4.2.dist-info/entry_points.txt +2 -0
  26. {cli → hafnia_cli}/__main__.py +2 -2
  27. {cli → hafnia_cli}/config.py +2 -2
  28. {cli → hafnia_cli}/dataset_cmds.py +2 -2
  29. {cli → hafnia_cli}/dataset_recipe_cmds.py +1 -1
  30. {cli → hafnia_cli}/experiment_cmds.py +1 -1
  31. {cli → hafnia_cli}/profile_cmds.py +2 -2
  32. {cli → hafnia_cli}/runc_cmds.py +1 -1
  33. {cli → hafnia_cli}/trainer_package_cmds.py +2 -2
  34. hafnia-0.4.0.dist-info/RECORD +0 -56
  35. hafnia-0.4.0.dist-info/entry_points.txt +0 -2
  36. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/WHEEL +0 -0
  37. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/licenses/LICENSE +0 -0
  38. {cli → hafnia_cli}/__init__.py +0 -0
  39. {cli → hafnia_cli}/consts.py +0 -0
  40. {cli → hafnia_cli}/keychain.py +0 -0
hafnia/__init__.py CHANGED
@@ -3,4 +3,4 @@ from importlib.metadata import version
  __package_name__ = "hafnia"
  __version__ = version(__package_name__)

- __dataset_format_version__ = "0.1.0" # Hafnia dataset format version
+ __dataset_format_version__ = "0.2.0" # Hafnia dataset format version
hafnia/dataset/dataset_names.py CHANGED
@@ -1,5 +1,8 @@
  from enum import Enum
- from typing import List
+ from typing import Dict, List, Optional
+
+ import boto3
+ from pydantic import BaseModel, field_validator

  FILENAME_RECIPE_JSON = "recipe.json"
  FILENAME_DATASET_INFO = "dataset_info.json"
@@ -23,7 +26,7 @@ TAG_IS_SAMPLE = "sample"
  OPS_REMOVE_CLASS = "__REMOVE__"


- class FieldName:
+ class PrimitiveField:
      CLASS_NAME: str = "class_name" # Name of the class this primitive is associated with, e.g. "car" for Bbox
      CLASS_IDX: str = "class_idx" # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class # noqa: E501
      OBJECT_ID: str = "object_id" # Unique identifier for the object, e.g. "12345123"
@@ -38,40 +41,150 @@ class FieldName:
          Returns a list of expected field names for primitives.
          """
          return [
-             FieldName.CLASS_NAME,
-             FieldName.CLASS_IDX,
-             FieldName.OBJECT_ID,
-             FieldName.CONFIDENCE,
-             FieldName.META,
-             FieldName.TASK_NAME,
+             PrimitiveField.CLASS_NAME,
+             PrimitiveField.CLASS_IDX,
+             PrimitiveField.OBJECT_ID,
+             PrimitiveField.CONFIDENCE,
+             PrimitiveField.META,
+             PrimitiveField.TASK_NAME,
          ]


- class ColumnName:
-     SAMPLE_INDEX: str = "sample_index"
+ class SampleField:
      FILE_PATH: str = "file_path"
      HEIGHT: str = "height"
      WIDTH: str = "width"
      SPLIT: str = "split"
+     TAGS: str = "tags"
+
+     CLASSIFICATIONS: str = "classifications"
+     BBOXES: str = "bboxes"
+     BITMASKS: str = "bitmasks"
+     POLYGONS: str = "polygons"
+
+     STORAGE_FORMAT: str = "storage_format" # E.g. "image", "video", "zip"
+     COLLECTION_INDEX: str = "collection_index"
+     COLLECTION_ID: str = "collection_id"
      REMOTE_PATH: str = "remote_path" # Path to the file in remote storage, e.g. S3
+     SAMPLE_INDEX: str = "sample_index"
+
      ATTRIBUTION: str = "attribution" # Attribution for the sample (image/video), e.g. creator, license, source, etc.
-     TAGS: str = "tags"
      META: str = "meta"
      DATASET_NAME: str = "dataset_name"


+ class StorageFormat:
+     IMAGE: str = "image"
+     VIDEO: str = "video"
+     ZIP: str = "zip"
+
+
  class SplitName:
-     TRAIN = "train"
-     VAL = "validation"
-     TEST = "test"
-     UNDEFINED = "UNDEFINED"
+     TRAIN: str = "train"
+     VAL: str = "validation"
+     TEST: str = "test"
+     UNDEFINED: str = "UNDEFINED"

      @staticmethod
      def valid_splits() -> List[str]:
          return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]

+     @staticmethod
+     def all_split_names() -> List[str]:
+         return [*SplitName.valid_splits(), SplitName.UNDEFINED]
+

  class DatasetVariant(Enum):
      DUMP = "dump"
      SAMPLE = "sample"
      HIDDEN = "hidden"
+
+
+ class AwsCredentials(BaseModel):
+     access_key: str
+     secret_key: str
+     session_token: str
+     region: Optional[str]
+
+     def aws_credentials(self) -> Dict[str, str]:
+         """
+         Returns the AWS credentials as a dictionary.
+         """
+         environment_vars = {
+             "AWS_ACCESS_KEY_ID": self.access_key,
+             "AWS_SECRET_ACCESS_KEY": self.secret_key,
+             "AWS_SESSION_TOKEN": self.session_token,
+         }
+         if self.region:
+             environment_vars["AWS_REGION"] = self.region
+
+         return environment_vars
+
+     @staticmethod
+     def from_session(session: boto3.Session) -> "AwsCredentials":
+         """
+         Creates AwsCredentials from a Boto3 session.
+         """
+         frozen_credentials = session.get_credentials().get_frozen_credentials()
+         return AwsCredentials(
+             access_key=frozen_credentials.access_key,
+             secret_key=frozen_credentials.secret_key,
+             session_token=frozen_credentials.token,
+             region=session.region_name,
+         )
+
+
+ ARN_PREFIX = "arn:aws:s3:::"
+
+
+ class ResourceCredentials(AwsCredentials):
+     s3_arn: str
+
+     @staticmethod
+     def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+         """
+         The endpoint returns a payload with a key called 's3_path', but it
+         is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+         """
+         if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+             payload["s3_arn"] = payload.pop("s3_path")
+
+         if "region" not in payload:
+             payload["region"] = "eu-west-1"
+         return ResourceCredentials(**payload)
+
+     @field_validator("s3_arn")
+     @classmethod
+     def validate_s3_arn(cls, value: str) -> str:
+         """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+         if not value.startswith("arn:aws:s3:::"):
+             raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+         return value
+
+     def s3_path(self) -> str:
+         """
+         Extracts the S3 path from the ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+         """
+         return self.s3_arn[len(ARN_PREFIX) :]
+
+     def s3_uri(self) -> str:
+         """
+         Converts the S3 ARN to a URI format.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+         """
+         return f"s3://{self.s3_path()}"
+
+     def bucket_name(self) -> str:
+         """
+         Extracts the bucket name from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+         """
+         return self.s3_path().split("/")[0]
+
+     def object_key(self) -> str:
+         """
+         Extracts the object key from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+         """
+         return "/".join(self.s3_path().split("/")[1:])
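The `ResourceCredentials` helper added above normalizes the endpoint payload (renaming `s3_path` to `s3_arn` and defaulting the region) and derives bucket, key, and URI values from the ARN. A minimal sketch of how the pieces fit together; the payload values below are placeholders, not taken from the package:

```python
from hafnia.dataset.dataset_names import ResourceCredentials

# Placeholder payload in the shape described by ResourceCredentials.fix_naming().
payload = {
    "access_key": "AKIA-EXAMPLE",
    "secret_key": "example-secret",
    "session_token": "example-token",
    "s3_path": "arn:aws:s3:::my-bucket/datasets/example",  # renamed to "s3_arn" by fix_naming()
}

creds = ResourceCredentials.fix_naming(payload)
assert creds.s3_uri() == "s3://my-bucket/datasets/example"
assert creds.bucket_name() == "my-bucket"
assert creds.object_key() == "datasets/example"

# aws_credentials() (inherited from AwsCredentials) returns environment-variable style keys;
# AWS_REGION is included because fix_naming() defaults the region to "eu-west-1".
env = creds.aws_credentials()
assert env["AWS_REGION"] == "eu-west-1"
```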
hafnia/dataset/dataset_recipe/dataset_recipe.py CHANGED
@@ -98,8 +98,8 @@ class DatasetRecipe(Serializable):
      @staticmethod
      def from_recipe_id(recipe_id: str) -> "DatasetRecipe":
          """Loads a dataset recipe by id from the hafnia platform."""
-         from cli.config import Config
          from hafnia.platform.dataset_recipe import get_dataset_recipe_by_id
+         from hafnia_cli.config import Config

          cfg = Config()
          endpoint_dataset = cfg.get_platform_endpoint("dataset_recipes")
@@ -114,8 +114,8 @@ class DatasetRecipe(Serializable):
      @staticmethod
      def from_recipe_name(name: str) -> "DatasetRecipe":
          """Loads a dataset recipe by name from the hafnia platform"""
-         from cli.config import Config
          from hafnia.platform.dataset_recipe import get_dataset_recipe_by_name
+         from hafnia_cli.config import Config

          cfg = Config()
          endpoint_dataset = cfg.get_platform_endpoint("dataset_recipes")
@@ -239,8 +239,8 @@ class DatasetRecipe(Serializable):

      def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
          """Uploads dataset recipe to the hafnia platform."""
-         from cli.config import Config
          from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
+         from hafnia_cli.config import Config

          recipe = self.as_dict()
          cfg = Config()
hafnia/dataset/dataset_upload_helper.py CHANGED
@@ -11,13 +11,12 @@ import polars as pl
  from PIL import Image
  from pydantic import BaseModel, ConfigDict, field_validator

- from cli.config import Config
  from hafnia.dataset import primitives
  from hafnia.dataset.dataset_names import (
-     ColumnName,
      DatasetVariant,
      DeploymentStage,
-     FieldName,
+     PrimitiveField,
+     SampleField,
      SplitName,
  )
  from hafnia.dataset.hafnia_dataset import Attribution, HafniaDataset, Sample, TaskInfo
@@ -33,6 +32,7 @@ from hafnia.dataset.primitives.primitive import Primitive
  from hafnia.http import post
  from hafnia.log import user_logger
  from hafnia.platform.datasets import get_dataset_id
+ from hafnia_cli.config import Config


  def generate_bucket_name(dataset_name: str, deployment_stage: DeploymentStage) -> str:
@@ -193,7 +193,7 @@ class Annotations(BaseModel):
      in gallery images on the dataset detail page.
      """

-     objects: Optional[List[Bbox]] = None
+     bboxes: Optional[List[Bbox]] = None
      classifications: Optional[List[Classification]] = None
      polygons: Optional[List[Polygon]] = None
      bitmasks: Optional[List[Bitmask]] = None
@@ -210,13 +210,15 @@ class DatasetImageMetadata(BaseModel):
      @classmethod
      def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
          sample = sample.model_copy(deep=True)
+         if sample.file_path is None:
+             raise ValueError("Sample has no file_path defined.")
          sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
          metadata = {}
          metadata_field_names = [
-             ColumnName.FILE_PATH,
-             ColumnName.HEIGHT,
-             ColumnName.WIDTH,
-             ColumnName.SPLIT,
+             SampleField.FILE_PATH,
+             SampleField.HEIGHT,
+             SampleField.WIDTH,
+             SampleField.SPLIT,
          ]
          for field_name in metadata_field_names:
              if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
@@ -224,7 +226,7 @@

          obj = DatasetImageMetadata(
              annotations=Annotations(
-                 objects=sample.objects,
+                 bboxes=sample.bboxes,
                  classifications=sample.classifications,
                  polygons=sample.polygons,
                  bitmasks=sample.bitmasks,
@@ -343,13 +345,13 @@ def calculate_distribution_values(
      classifications = dataset_split.select(pl.col(classification_column).explode())
      classifications = classifications.filter(pl.col(classification_column).is_not_null()).unnest(classification_column)
      classifications = classifications.filter(
-         pl.col(FieldName.TASK_NAME).is_in([task.name for task in distribution_tasks])
+         pl.col(PrimitiveField.TASK_NAME).is_in([task.name for task in distribution_tasks])
      )
      dist_values = []
-     for (task_name,), task_group in classifications.group_by(FieldName.TASK_NAME):
+     for (task_name,), task_group in classifications.group_by(PrimitiveField.TASK_NAME):
          distribution_type = DbDistributionType(name=task_name)
          n_annotated_total = len(task_group)
-         for (class_name,), class_group in task_group.group_by(FieldName.CLASS_NAME):
+         for (class_name,), class_group in task_group.group_by(PrimitiveField.CLASS_NAME):
              class_count = len(class_group)

              dist_values.append(
@@ -383,6 +385,7 @@ def dataset_info_from_dataset(
      path_hidden: Optional[Path],
      path_gallery_images: Optional[Path] = None,
      gallery_image_names: Optional[List[str]] = None,
+     distribution_task_names: Optional[List[TaskInfo]] = None,
  ) -> DbDataset:
      dataset_variants = []
      dataset_reports = []
@@ -427,13 +430,15 @@
              )
          )

+         distribution_task_names = distribution_task_names or []
+         distribution_tasks = [t for t in dataset.info.tasks if t.name in distribution_task_names]
          for split_name in SplitChoices:
              split_names = SPLIT_CHOICE_MAPPING[split_name]
-             dataset_split = dataset_variant.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+             dataset_split = dataset_variant.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))

              distribution_values = calculate_distribution_values(
                  dataset_split=dataset_split,
-                 distribution_tasks=dataset.info.distributions,
+                 distribution_tasks=distribution_tasks,
              )
              report = DbSplitAnnotationsReport(
                  variant_type=VARIANT_TYPE_MAPPING[variant_type], # type: ignore[index]
@@ -461,7 +466,7 @@

      annotation_type = DbAnnotationType(name=AnnotationType.ObjectDetection.value)
      for (class_name, task_name), class_group in df_per_instance.group_by(
-         FieldName.CLASS_NAME, FieldName.TASK_NAME
+         PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
      ):
          if class_name is None:
              continue
@@ -473,10 +478,10 @@
                  annotation_type=annotation_type,
                  task_name=task_name,
              ),
-             unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+             unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
              obj_instances=len(class_group),
              annotation_type=[annotation_type],
-             images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+             images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
              area_avg_ratio=class_group["area"].mean(),
              area_min_ratio=class_group["area"].min(),
              area_max_ratio=class_group["area"].max(),
@@ -495,7 +500,7 @@
              width_avg_px=class_group["width_px"].mean(),
              width_min_px=int(class_group["width_px"].min()),
              width_max_px=int(class_group["width_px"].max()),
-             average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+             average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
          )
      )

@@ -509,13 +514,13 @@

      # Include only classification tasks that are defined in the dataset info
      classification_df = classification_df.filter(
-         pl.col(FieldName.TASK_NAME).is_in(classification_tasks)
+         pl.col(PrimitiveField.TASK_NAME).is_in(classification_tasks)
      )

      for (
          task_name,
          class_name,
-     ), class_group in classification_df.group_by(FieldName.TASK_NAME, FieldName.CLASS_NAME):
+     ), class_group in classification_df.group_by(PrimitiveField.TASK_NAME, PrimitiveField.CLASS_NAME):
          if class_name is None:
              continue
          if task_name == Classification.default_task_name():
@@ -544,7 +549,7 @@
      if has_primitive(dataset_split, PrimitiveType=Bitmask):
          col_name = Bitmask.column_name()
          drop_columns = [col for col in primitive_columns if col != col_name]
-         drop_columns.append(FieldName.META)
+         drop_columns.append(PrimitiveField.META)

          df_per_instance = table_transformations.create_primitive_table(
              dataset_split, PrimitiveType=Bitmask, keep_sample_data=True
@@ -562,7 +567,7 @@

      annotation_type = DbAnnotationType(name=AnnotationType.InstanceSegmentation)
      for (class_name, task_name), class_group in df_per_instance.group_by(
-         FieldName.CLASS_NAME, FieldName.TASK_NAME
+         PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
      ):
          if class_name is None:
              continue
@@ -574,11 +579,11 @@
                  annotation_type=annotation_type,
                  task_name=task_name,
              ),
-             unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+             unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
              obj_instances=len(class_group),
              annotation_type=[annotation_type],
-             average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
-             images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+             average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
+             images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
              area_avg_ratio=class_group["area"].mean(),
              area_min_ratio=class_group["area"].min(),
              area_max_ratio=class_group["area"].max(),
@@ -646,7 +651,7 @@ def create_gallery_images(
      path_gallery_images.mkdir(parents=True, exist_ok=True)
      COL_IMAGE_NAME = "image_name"
      samples = dataset.samples.with_columns(
-         dataset.samples[ColumnName.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
+         dataset.samples[SampleField.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
      )
      gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))

hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} RENAMED
@@ -1,23 +1,27 @@
  import shutil
  from pathlib import Path
- from typing import List, Optional
+ from typing import TYPE_CHECKING, List, Optional

  import more_itertools
  import polars as pl
  from PIL import Image
  from rich.progress import track

- from hafnia.dataset.dataset_names import ColumnName, FieldName
- from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+ from hafnia.dataset.dataset_names import PrimitiveField, SampleField
  from hafnia.dataset.primitives import Classification
  from hafnia.utils import is_image_file

+ if TYPE_CHECKING:
+     from hafnia.dataset.hafnia_dataset import HafniaDataset

- def import_image_classification_directory_tree(
+
+ def from_image_classification_folder(
      path_folder: Path,
      split: str,
      n_samples: Optional[int] = None,
- ) -> HafniaDataset:
+ ) -> "HafniaDataset":
+     from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
      class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
      class_names = sorted([folder.name for folder in class_folder_paths]) # Sort for determinism

@@ -62,8 +66,8 @@ def import_image_classification_directory_tree(
      return hafnia_dataset


- def export_image_classification_directory_tree(
-     dataset: HafniaDataset,
+ def to_image_classification_folder(
+     dataset: "HafniaDataset",
      path_output: Path,
      task_name: Optional[str] = None,
      clean_folder: bool = False,
@@ -72,7 +76,7 @@ def export_image_classification_directory_tree(

      samples = dataset.samples.with_columns(
          pl.col(task.primitive.column_name())
-         .list.filter(pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+         .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
          .alias(task.primitive.column_name())
      )

@@ -95,11 +99,11 @@
          if len(classifications) != 1:
              raise ValueError("Each sample should have exactly one classification.")
          classification = classifications[0]
-         class_name = classification[FieldName.CLASS_NAME].replace("/", "_") # Avoid issues with subfolders
+         class_name = classification[PrimitiveField.CLASS_NAME].replace("/", "_") # Avoid issues with subfolders
          path_class_folder = path_output / class_name
          path_class_folder.mkdir(parents=True, exist_ok=True)

-         path_image_org = Path(sample_dict[ColumnName.FILE_PATH])
+         path_image_org = Path(sample_dict[SampleField.FILE_PATH])
          path_image_new = path_class_folder / path_image_org.name
          shutil.copy2(path_image_org, path_image_new)

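The renamed helpers keep the same folder-tree contract: one subfolder per class name, images inside. Going only by the signatures in this hunk, a hypothetical usage sketch (paths, split, and sample count are placeholders):

```python
from pathlib import Path

from hafnia.dataset.format_conversions.format_image_classification_folder import (
    from_image_classification_folder,
    to_image_classification_folder,
)

# Import a <class-name>/<image> folder tree as a HafniaDataset.
dataset = from_image_classification_folder(
    path_folder=Path("data/flowers"),  # placeholder path
    split="train",
    n_samples=100,  # optional cap on the number of imported samples
)

# Export it back to a class-per-folder tree, wiping the target folder first.
to_image_classification_folder(dataset, path_output=Path("exports/flowers"), clean_folder=True)
```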
hafnia/dataset/format_conversions/format_yolo.py ADDED
@@ -0,0 +1,164 @@
+ import shutil
+ from pathlib import Path
+ from typing import TYPE_CHECKING, List, Optional
+
+ from PIL import Image
+ from rich.progress import track
+
+ from hafnia.dataset import primitives
+ from hafnia.dataset.dataset_names import SplitName
+
+ if TYPE_CHECKING:
+     from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+ FILENAME_YOLO_CLASS_NAMES = "obj.names"
+ FILENAME_YOLO_IMAGES_TXT = "images.txt"
+
+
+ def get_image_size(path: Path) -> tuple[int, int]:
+     with Image.open(path) as img:
+         return img.size # (width, height)
+
+
+ def from_yolo_format(
+     path_yolo_dataset: Path,
+     split_name: str = SplitName.UNDEFINED,
+     dataset_name: str = "yolo-dataset",
+     filename_class_names: str = FILENAME_YOLO_CLASS_NAMES,
+     filename_images_txt: str = FILENAME_YOLO_IMAGES_TXT,
+ ) -> "HafniaDataset":
+     """
+     Imports a YOLO (Darknet) formatted dataset as a HafniaDataset.
+     """
+     from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
+     path_class_names = path_yolo_dataset / filename_class_names
+
+     if split_name not in SplitName.all_split_names():
+         raise ValueError(f"Invalid split name: {split_name}. Must be one of {SplitName.all_split_names()}")
+
+     if not path_class_names.exists():
+         raise FileNotFoundError(f"File with class names not found at '{path_class_names.resolve()}'.")
+
+     class_names_text = path_class_names.read_text()
+     if class_names_text.strip() == "":
+         raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' is empty")
+
+     class_names = [class_name for class_name in class_names_text.splitlines() if class_name.strip() != ""]
+
+     if len(class_names) == 0:
+         raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' has no class names")
+
+     path_images_txt = path_yolo_dataset / filename_images_txt
+
+     if not path_images_txt.exists():
+         raise FileNotFoundError(f"File with images not found at '{path_images_txt.resolve()}'")
+
+     images_txt_text = path_images_txt.read_text()
+     if len(images_txt_text.strip()) == 0:
+         raise ValueError(f"File is empty at '{path_images_txt.resolve()}'")
+
+     image_paths_raw = [line.strip() for line in images_txt_text.splitlines()]
+
+     samples: List[Sample] = []
+     for image_path_raw in track(image_paths_raw):
+         path_image = path_yolo_dataset / image_path_raw
+         if not path_image.exists():
+             raise FileNotFoundError(f"File with image not found at '{path_image.resolve()}'")
+         width, height = get_image_size(path_image)
+
+         path_label = path_image.with_suffix(".txt")
+         if not path_label.exists():
+             raise FileNotFoundError(f"File with labels not found at '{path_label.resolve()}'")
+
+         boxes: List[primitives.Bbox] = []
+         bbox_strings = path_label.read_text().splitlines()
+         for bbox_string in bbox_strings:
+             parts = bbox_string.strip().split()
+             if len(parts) != 5:
+                 raise ValueError(f"Invalid bbox format in file {path_label.resolve()}: {bbox_string}")
+
+             class_idx = int(parts[0])
+             x_center, y_center, bbox_width, bbox_height = (float(value) for value in parts[1:5])
+
+             top_left_x = x_center - bbox_width / 2
+             top_left_y = y_center - bbox_height / 2
+
+             bbox = primitives.Bbox(
+                 top_left_x=top_left_x,
+                 top_left_y=top_left_y,
+                 width=bbox_width,
+                 height=bbox_height,
+                 class_idx=class_idx,
+                 class_name=class_names[class_idx] if 0 <= class_idx < len(class_names) else None,
+             )
+             boxes.append(bbox)
+
+         sample = Sample(
+             file_path=path_image.absolute().as_posix(),
+             height=height,
+             width=width,
+             split=split_name,
+             bboxes=boxes,
+         )
+         samples.append(sample)
+
+     tasks = [TaskInfo(primitive=primitives.Bbox, class_names=class_names)]
+     info = DatasetInfo(dataset_name=dataset_name, tasks=tasks)
+     hafnia_dataset = HafniaDataset.from_samples_list(samples, info=info)
+     return hafnia_dataset
+
+
+ def to_yolo_format(
+     dataset: "HafniaDataset",
+     path_export_yolo_dataset: Path,
+     task_name: Optional[str] = None,
+ ):
+     """Exports a HafniaDataset as YOLO (Darknet) format."""
+     from hafnia.dataset.hafnia_dataset import Sample
+
+     bbox_task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitives.Bbox)
+
+     class_names = bbox_task.class_names or []
+     if len(class_names) == 0:
+         raise ValueError(
+             f"Hafnia dataset task '{bbox_task.name}' has no class names defined. This is required for YOLO export."
+         )
+     path_export_yolo_dataset.mkdir(parents=True, exist_ok=True)
+     path_class_names = path_export_yolo_dataset / FILENAME_YOLO_CLASS_NAMES
+     path_class_names.write_text("\n".join(class_names))
+
+     path_data_folder = path_export_yolo_dataset / "data"
+     path_data_folder.mkdir(parents=True, exist_ok=True)
+     image_paths: List[str] = []
+     for sample_dict in dataset:
+         sample = Sample(**sample_dict)
+         if sample.file_path is None:
+             raise ValueError("Sample has no file_path defined.")
+         path_image_src = Path(sample.file_path)
+         path_image_dst = path_data_folder / path_image_src.name
+         shutil.copy2(path_image_src, path_image_dst)
+         image_paths.append(path_image_dst.relative_to(path_export_yolo_dataset).as_posix())
+         path_label = path_image_dst.with_suffix(".txt")
+         bboxes = sample.bboxes or []
+         bbox_strings = [bbox_to_yolo_format(bbox) for bbox in bboxes]
+         path_label.write_text("\n".join(bbox_strings))
+
+     path_images_txt = path_export_yolo_dataset / FILENAME_YOLO_IMAGES_TXT
+     path_images_txt.write_text("\n".join(image_paths))
+
+
+ def bbox_to_yolo_format(bbox: primitives.Bbox) -> str:
+     """
+     From hafnia bbox to yolo bbox string conversion
+     Both yolo and hafnia use normalized coordinates [0, 1]
+     Hafnia: top_left_x, top_left_y, width, height
+     Yolo (darknet): "<object-class> <x_center> <y_center> <width> <height>"
+     Example (3 bounding boxes):
+     1 0.716797 0.395833 0.216406 0.147222
+     0 0.687109 0.379167 0.255469 0.158333
+     1 0.420312 0.395833 0.140625 0.166667
+     """
+     x_center = bbox.top_left_x + bbox.width / 2
+     y_center = bbox.top_left_y + bbox.height / 2
+     return f"{bbox.class_idx} {x_center} {y_center} {bbox.width} {bbox.height}"
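A hedged usage sketch of the new YOLO (Darknet) conversion helpers, assuming the layout the importer expects: an `obj.names` file with one class per line, an `images.txt` listing image paths, and a `.txt` label file next to each image. All paths below are placeholders:

```python
from pathlib import Path

from hafnia.dataset.dataset_names import SplitName
from hafnia.dataset.format_conversions.format_yolo import from_yolo_format, to_yolo_format

# Import: reads obj.names and images.txt, then parses per-image label files into Bbox
# primitives (YOLO center coordinates are converted to Hafnia top-left coordinates).
dataset = from_yolo_format(
    path_yolo_dataset=Path("datasets/traffic-yolo"),  # placeholder path
    split_name=SplitName.TRAIN,
    dataset_name="traffic",
)

# Export: writes obj.names and images.txt, copies images into data/ and writes one
# "<class-idx> <x_center> <y_center> <width> <height>" line per bounding box.
to_yolo_format(dataset, path_export_yolo_dataset=Path("exports/traffic-yolo"))
```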