hafnia 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_names.py +128 -15
- hafnia/dataset/dataset_upload_helper.py +30 -25
- hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
- hafnia/dataset/hafnia_dataset.py +246 -72
- hafnia/dataset/operations/dataset_stats.py +82 -70
- hafnia/dataset/operations/dataset_transformations.py +102 -37
- hafnia/dataset/operations/table_transformations.py +132 -15
- hafnia/dataset/primitives/bbox.py +3 -5
- hafnia/dataset/primitives/bitmask.py +2 -7
- hafnia/dataset/primitives/classification.py +3 -3
- hafnia/dataset/primitives/polygon.py +2 -4
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +2 -2
- hafnia/platform/datasets.py +3 -7
- hafnia/platform/download.py +1 -72
- hafnia/torch_helpers.py +12 -12
- hafnia/visualizations/image_visualizations.py +2 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +4 -4
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/RECORD +25 -24
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
hafnia/__init__.py
CHANGED
hafnia/dataset/dataset_names.py
CHANGED
```diff
@@ -1,5 +1,8 @@
 from enum import Enum
-from typing import List
+from typing import Dict, List, Optional
+
+import boto3
+from pydantic import BaseModel, field_validator
 
 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
```
```diff
@@ -23,7 +26,7 @@ TAG_IS_SAMPLE = "sample"
 OPS_REMOVE_CLASS = "__REMOVE__"
 
 
-class FieldName:
+class PrimitiveField:
     CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
     CLASS_IDX: str = "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class  # noqa: E501
     OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
@@ -38,40 +41,150 @@ class FieldName:
         Returns a list of expected field names for primitives.
         """
         return [
-            FieldName.CLASS_NAME,
-            FieldName.CLASS_IDX,
-            FieldName.OBJECT_ID,
-            FieldName.CONFIDENCE,
-            FieldName.META,
-            FieldName.TASK_NAME,
+            PrimitiveField.CLASS_NAME,
+            PrimitiveField.CLASS_IDX,
+            PrimitiveField.OBJECT_ID,
+            PrimitiveField.CONFIDENCE,
+            PrimitiveField.META,
+            PrimitiveField.TASK_NAME,
         ]
 
 
-class ColumnName:
-    SAMPLE_INDEX: str = "sample_index"
+class SampleField:
     FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
+    TAGS: str = "tags"
+
+    CLASSIFICATIONS: str = "classifications"
+    BBOXES: str = "bboxes"
+    BITMASKS: str = "bitmasks"
+    POLYGONS: str = "polygons"
+
+    STORAGE_FORMAT: str = "storage_format"  # E.g. "image", "video", "zip"
+    COLLECTION_INDEX: str = "collection_index"
+    COLLECTION_ID: str = "collection_id"
     REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+    SAMPLE_INDEX: str = "sample_index"
+
     ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
-    TAGS: str = "tags"
     META: str = "meta"
     DATASET_NAME: str = "dataset_name"
 
 
+class StorageFormat:
+    IMAGE: str = "image"
+    VIDEO: str = "video"
+    ZIP: str = "zip"
+
+
 class SplitName:
-    TRAIN = "train"
-    VAL = "validation"
-    TEST = "test"
-    UNDEFINED = "UNDEFINED"
+    TRAIN: str = "train"
+    VAL: str = "validation"
+    TEST: str = "test"
+    UNDEFINED: str = "UNDEFINED"
 
     @staticmethod
     def valid_splits() -> List[str]:
         return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
 
+    @staticmethod
+    def all_split_names() -> List[str]:
+        return [*SplitName.valid_splits(), SplitName.UNDEFINED]
+
 
 class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
+
+
+class AwsCredentials(BaseModel):
+    access_key: str
+    secret_key: str
+    session_token: str
+    region: Optional[str]
+
+    def aws_credentials(self) -> Dict[str, str]:
+        """
+        Returns the AWS credentials as a dictionary.
+        """
+        environment_vars = {
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_SESSION_TOKEN": self.session_token,
+        }
+        if self.region:
+            environment_vars["AWS_REGION"] = self.region
+
+        return environment_vars
+
+    @staticmethod
+    def from_session(session: boto3.Session) -> "AwsCredentials":
+        """
+        Creates AwsCredentials from a Boto3 session.
+        """
+        frozen_credentials = session.get_credentials().get_frozen_credentials()
+        return AwsCredentials(
+            access_key=frozen_credentials.access_key,
+            secret_key=frozen_credentials.secret_key,
+            session_token=frozen_credentials.token,
+            region=session.region_name,
+        )
+
+
+ARN_PREFIX = "arn:aws:s3:::"
+
+
+class ResourceCredentials(AwsCredentials):
+    s3_arn: str
+
+    @staticmethod
+    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+        """
+        The endpoint returns a payload with a key called 's3_path', but it
+        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+        """
+        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+            payload["s3_arn"] = payload.pop("s3_path")
+
+        if "region" not in payload:
+            payload["region"] = "eu-west-1"
+        return ResourceCredentials(**payload)
+
+    @field_validator("s3_arn")
+    @classmethod
+    def validate_s3_arn(cls, value: str) -> str:
+        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+        if not value.startswith("arn:aws:s3:::"):
+            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+        return value
+
+    def s3_path(self) -> str:
+        """
+        Extracts the S3 path from the ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+        """
+        return self.s3_arn[len(ARN_PREFIX) :]
+
+    def s3_uri(self) -> str:
+        """
+        Converts the S3 ARN to a URI format.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+        """
+        return f"s3://{self.s3_path()}"
+
+    def bucket_name(self) -> str:
+        """
+        Extracts the bucket name from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+        """
+        return self.s3_path().split("/")[0]
+
+    def object_key(self) -> str:
+        """
+        Extracts the object key from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+        """
+        return "/".join(self.s3_path().split("/")[1:])
```
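The new `AwsCredentials`/`ResourceCredentials` models centralize how credential payloads are parsed and exposed as environment variables. A minimal sketch of how the pieces added above compose (the payload values are placeholders, not real credentials):

```python
from hafnia.dataset.dataset_names import ResourceCredentials

# Payload shape as described in fix_naming(): the endpoint sends 's3_path',
# which actually holds an ARN; fix_naming() renames it to 's3_arn' and
# defaults the region to "eu-west-1" when missing.
payload = {
    "access_key": "AKIA-PLACEHOLDER",
    "secret_key": "secret-placeholder",
    "session_token": "token-placeholder",
    "s3_path": "arn:aws:s3:::my-bucket/datasets/my-dataset",
}
creds = ResourceCredentials.fix_naming(payload)

assert creds.s3_uri() == "s3://my-bucket/datasets/my-dataset"
assert creds.bucket_name() == "my-bucket"
assert creds.object_key() == "datasets/my-dataset"

# aws_credentials() returns the dict of AWS_* environment variables,
# ready to inject into a subprocess environment.
env_vars = creds.aws_credentials()
```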
hafnia/dataset/dataset_upload_helper.py
CHANGED
```diff
@@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, field_validator
 from cli.config import Config
 from hafnia.dataset import primitives
 from hafnia.dataset.dataset_names import (
-    ColumnName,
     DatasetVariant,
     DeploymentStage,
-    FieldName,
+    PrimitiveField,
+    SampleField,
     SplitName,
 )
 from hafnia.dataset.hafnia_dataset import Attribution, HafniaDataset, Sample, TaskInfo
@@ -193,7 +193,7 @@ class Annotations(BaseModel):
     in gallery images on the dataset detail page.
     """
 
-
+    bboxes: Optional[List[Bbox]] = None
     classifications: Optional[List[Classification]] = None
     polygons: Optional[List[Polygon]] = None
     bitmasks: Optional[List[Bitmask]] = None
@@ -210,13 +210,15 @@ class DatasetImageMetadata(BaseModel):
     @classmethod
     def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
         sample = sample.model_copy(deep=True)
+        if sample.file_path is None:
+            raise ValueError("Sample has no file_path defined.")
         sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
         metadata = {}
         metadata_field_names = [
-            ColumnName.FILE_PATH,
-            ColumnName.HEIGHT,
-            ColumnName.WIDTH,
-            ColumnName.SPLIT,
+            SampleField.FILE_PATH,
+            SampleField.HEIGHT,
+            SampleField.WIDTH,
+            SampleField.SPLIT,
         ]
         for field_name in metadata_field_names:
             if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
@@ -224,7 +226,7 @@ class DatasetImageMetadata(BaseModel):
 
         obj = DatasetImageMetadata(
             annotations=Annotations(
-
+                bboxes=sample.bboxes,
                 classifications=sample.classifications,
                 polygons=sample.polygons,
                 bitmasks=sample.bitmasks,
@@ -343,13 +345,13 @@ def calculate_distribution_values(
     classifications = dataset_split.select(pl.col(classification_column).explode())
     classifications = classifications.filter(pl.col(classification_column).is_not_null()).unnest(classification_column)
     classifications = classifications.filter(
-        pl.col(FieldName.TASK_NAME).is_in([task.name for task in distribution_tasks])
+        pl.col(PrimitiveField.TASK_NAME).is_in([task.name for task in distribution_tasks])
     )
     dist_values = []
-    for (task_name,), task_group in classifications.group_by(FieldName.TASK_NAME):
+    for (task_name,), task_group in classifications.group_by(PrimitiveField.TASK_NAME):
         distribution_type = DbDistributionType(name=task_name)
         n_annotated_total = len(task_group)
-        for (class_name,), class_group in task_group.group_by(FieldName.CLASS_NAME):
+        for (class_name,), class_group in task_group.group_by(PrimitiveField.CLASS_NAME):
            class_count = len(class_group)
 
            dist_values.append(
```
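As an aside, the `calculate_distribution_values` change is mostly the `FieldName` → `PrimitiveField` rename, but the explode/unnest/group_by pattern it relies on is worth spelling out. A self-contained sketch with a toy table (illustrative data, not the library's actual schema):

```python
import polars as pl

# Each row holds a list of classification structs, mirroring the
# "classifications" column; struct field names follow PrimitiveField.
df = pl.DataFrame(
    {
        "classifications": [
            [{"task_name": "weather", "class_name": "sunny"}],
            [{"task_name": "weather", "class_name": "rainy"}],
            [{"task_name": "weather", "class_name": "sunny"}],
        ]
    }
)

# Flatten list-of-structs into one row per annotation, then aggregate.
flat = df.select(pl.col("classifications").explode())
flat = flat.filter(pl.col("classifications").is_not_null()).unnest("classifications")

for (task_name,), task_group in flat.group_by("task_name"):
    n_annotated_total = len(task_group)
    for (class_name,), class_group in task_group.group_by("class_name"):
        print(task_name, class_name, len(class_group) / n_annotated_total)
```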
```diff
@@ -383,6 +385,7 @@ def dataset_info_from_dataset(
     path_hidden: Optional[Path],
     path_gallery_images: Optional[Path] = None,
     gallery_image_names: Optional[List[str]] = None,
+    distribution_task_names: Optional[List[TaskInfo]] = None,
 ) -> DbDataset:
     dataset_variants = []
     dataset_reports = []
@@ -427,13 +430,15 @@ def dataset_info_from_dataset(
         )
     )
 
+    distribution_task_names = distribution_task_names or []
+    distribution_tasks = [t for t in dataset.info.tasks if t.name in distribution_task_names]
     for split_name in SplitChoices:
         split_names = SPLIT_CHOICE_MAPPING[split_name]
-        dataset_split = dataset_variant.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+        dataset_split = dataset_variant.samples.filter(pl.col(SampleField.SPLIT).is_in(split_names))
 
         distribution_values = calculate_distribution_values(
             dataset_split=dataset_split,
-            distribution_tasks=
+            distribution_tasks=distribution_tasks,
         )
         report = DbSplitAnnotationsReport(
             variant_type=VARIANT_TYPE_MAPPING[variant_type],  # type: ignore[index]
@@ -461,7 +466,7 @@ def dataset_info_from_dataset(
 
     annotation_type = DbAnnotationType(name=AnnotationType.ObjectDetection.value)
     for (class_name, task_name), class_group in df_per_instance.group_by(
-        FieldName.CLASS_NAME, FieldName.TASK_NAME
+        PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
     ):
         if class_name is None:
             continue
@@ -473,10 +478,10 @@ def dataset_info_from_dataset(
                 annotation_type=annotation_type,
                 task_name=task_name,
             ),
-            unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+            unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
             obj_instances=len(class_group),
             annotation_type=[annotation_type],
-            images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+            images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
             area_avg_ratio=class_group["area"].mean(),
             area_min_ratio=class_group["area"].min(),
             area_max_ratio=class_group["area"].max(),
@@ -495,7 +500,7 @@ def dataset_info_from_dataset(
             width_avg_px=class_group["width_px"].mean(),
             width_min_px=int(class_group["width_px"].min()),
             width_max_px=int(class_group["width_px"].max()),
-            average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+            average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
         )
     )
 
@@ -509,13 +514,13 @@ def dataset_info_from_dataset(
 
     # Include only classification tasks that are defined in the dataset info
     classification_df = classification_df.filter(
-        pl.col(FieldName.TASK_NAME).is_in(classification_tasks)
+        pl.col(PrimitiveField.TASK_NAME).is_in(classification_tasks)
     )
 
     for (
         task_name,
         class_name,
-    ), class_group in classification_df.group_by(FieldName.TASK_NAME, FieldName.CLASS_NAME):
+    ), class_group in classification_df.group_by(PrimitiveField.TASK_NAME, PrimitiveField.CLASS_NAME):
         if class_name is None:
             continue
         if task_name == Classification.default_task_name():
@@ -544,7 +549,7 @@ def dataset_info_from_dataset(
     if has_primitive(dataset_split, PrimitiveType=Bitmask):
         col_name = Bitmask.column_name()
         drop_columns = [col for col in primitive_columns if col != col_name]
-        drop_columns.append(FieldName.META)
+        drop_columns.append(PrimitiveField.META)
 
         df_per_instance = table_transformations.create_primitive_table(
             dataset_split, PrimitiveType=Bitmask, keep_sample_data=True
@@ -562,7 +567,7 @@ def dataset_info_from_dataset(
 
     annotation_type = DbAnnotationType(name=AnnotationType.InstanceSegmentation)
     for (class_name, task_name), class_group in df_per_instance.group_by(
-        FieldName.CLASS_NAME, FieldName.TASK_NAME
+        PrimitiveField.CLASS_NAME, PrimitiveField.TASK_NAME
     ):
         if class_name is None:
             continue
@@ -574,11 +579,11 @@ def dataset_info_from_dataset(
                 annotation_type=annotation_type,
                 task_name=task_name,
             ),
-            unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
+            unique_obj_ids=class_group[PrimitiveField.OBJECT_ID].n_unique(),
             obj_instances=len(class_group),
             annotation_type=[annotation_type],
-            average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
-            images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+            average_count_per_image=len(class_group) / class_group[SampleField.SAMPLE_INDEX].n_unique(),
+            images_with_obj=class_group[SampleField.SAMPLE_INDEX].n_unique(),
             area_avg_ratio=class_group["area"].mean(),
             area_min_ratio=class_group["area"].min(),
             area_max_ratio=class_group["area"].max(),
@@ -646,7 +651,7 @@ def create_gallery_images(
     path_gallery_images.mkdir(parents=True, exist_ok=True)
     COL_IMAGE_NAME = "image_name"
     samples = dataset.samples.with_columns(
-        dataset.samples[ColumnName.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
+        dataset.samples[SampleField.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
     )
     gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))
 
```
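The `create_gallery_images` hunk matches gallery images by file name using polars string expressions only. A sketch of the same trick on toy paths (illustrative data):

```python
import polars as pl

samples = pl.DataFrame({"file_path": ["data/images/0001.jpg", "data/images/0002.jpg"]})

# Derive the image name column from the path, then filter on it.
samples = samples.with_columns(
    samples["file_path"].str.split("/").list.last().alias("image_name")
)
gallery = samples.filter(pl.col("image_name").is_in(["0002.jpg"]))
```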
hafnia/dataset/format_conversions/image_classification_from_directory.py → format_image_classification_folder.py
CHANGED
```diff
@@ -1,23 +1,27 @@
 import shutil
 from pathlib import Path
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import more_itertools
 import polars as pl
 from PIL import Image
 from rich.progress import track
 
-from hafnia.dataset.dataset_names import
-from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+from hafnia.dataset.dataset_names import PrimitiveField, SampleField
 from hafnia.dataset.primitives import Classification
 from hafnia.utils import is_image_file
 
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
 
-def import_image_classification_directory_tree(
+
+def from_image_classification_folder(
     path_folder: Path,
     split: str,
     n_samples: Optional[int] = None,
-) -> HafniaDataset:
+) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
     class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
     class_names = sorted([folder.name for folder in class_folder_paths])  # Sort for determinism
 
@@ -62,8 +66,8 @@ def import_image_classification_directory_tree(
     return hafnia_dataset
 
 
-def export_image_classification_directory_tree(
-    dataset: HafniaDataset,
+def to_image_classification_folder(
+    dataset: "HafniaDataset",
     path_output: Path,
     task_name: Optional[str] = None,
     clean_folder: bool = False,
@@ -72,7 +76,7 @@ def export_image_classification_directory_tree(
 
     samples = dataset.samples.with_columns(
         pl.col(task.primitive.column_name())
-        .list.filter(pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+        .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
         .alias(task.primitive.column_name())
     )
 
@@ -95,11 +99,11 @@ def export_image_classification_directory_tree(
         if len(classifications) != 1:
             raise ValueError("Each sample should have exactly one classification.")
         classification = classifications[0]
-        class_name = classification[FieldName.CLASS_NAME].replace("/", "_")  # Avoid issues with subfolders
+        class_name = classification[PrimitiveField.CLASS_NAME].replace("/", "_")  # Avoid issues with subfolders
         path_class_folder = path_output / class_name
         path_class_folder.mkdir(parents=True, exist_ok=True)
 
-        path_image_org = Path(sample_dict[ColumnName.FILE_PATH])
+        path_image_org = Path(sample_dict[SampleField.FILE_PATH])
         path_image_new = path_class_folder / path_image_org.name
         shutil.copy2(path_image_org, path_image_new)
 
```
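Hypothetical usage of the renamed pair (the function names and signatures are from the diff above; the paths and split value are placeholders):

```python
from pathlib import Path

from hafnia.dataset.format_conversions.format_image_classification_folder import (
    from_image_classification_folder,
    to_image_classification_folder,
)

# Import a folder-per-class tree as a HafniaDataset ...
dataset = from_image_classification_folder(Path("my_dataset/"), split="train")

# ... and write it back out as a folder-per-class tree.
to_image_classification_folder(dataset, path_output=Path("exported/"))
```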
hafnia/dataset/format_conversions/format_yolo.py
ADDED
```diff
@@ -0,0 +1,164 @@
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional
+
+from PIL import Image
+from rich.progress import track
+
+from hafnia.dataset import primitives
+from hafnia.dataset.dataset_names import SplitName
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+FILENAME_YOLO_CLASS_NAMES = "obj.names"
+FILENAME_YOLO_IMAGES_TXT = "images.txt"
+
+
+def get_image_size(path: Path) -> tuple[int, int]:
+    with Image.open(path) as img:
+        return img.size  # (width, height)
+
+
+def from_yolo_format(
+    path_yolo_dataset: Path,
+    split_name: str = SplitName.UNDEFINED,
+    dataset_name: str = "yolo-dataset",
+    filename_class_names: str = FILENAME_YOLO_CLASS_NAMES,
+    filename_images_txt: str = FILENAME_YOLO_IMAGES_TXT,
+) -> "HafniaDataset":
+    """
+    Imports a YOLO (Darknet) formatted dataset as a HafniaDataset.
+    """
+    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+
+    path_class_names = path_yolo_dataset / filename_class_names
+
+    if split_name not in SplitName.all_split_names():
+        raise ValueError(f"Invalid split name: {split_name}. Must be one of {SplitName.all_split_names()}")
+
+    if not path_class_names.exists():
+        raise FileNotFoundError(f"File with class names not found at '{path_class_names.resolve()}'.")
+
+    class_names_text = path_class_names.read_text()
+    if class_names_text.strip() == "":
+        raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' is empty")
+
+    class_names = [class_name for class_name in class_names_text.splitlines() if class_name.strip() != ""]
+
+    if len(class_names) == 0:
+        raise ValueError(f"File with class names not found at '{path_class_names.resolve()}' has no class names")
+
+    path_images_txt = path_yolo_dataset / filename_images_txt
+
+    if not path_images_txt.exists():
+        raise FileNotFoundError(f"File with images not found at '{path_images_txt.resolve()}'")
+
+    images_txt_text = path_images_txt.read_text()
+    if len(images_txt_text.strip()) == 0:
+        raise ValueError(f"File is empty at '{path_images_txt.resolve()}'")
+
+    image_paths_raw = [line.strip() for line in images_txt_text.splitlines()]
+
+    samples: List[Sample] = []
+    for image_path_raw in track(image_paths_raw):
+        path_image = path_yolo_dataset / image_path_raw
+        if not path_image.exists():
+            raise FileNotFoundError(f"File with image not found at '{path_image.resolve()}'")
+        width, height = get_image_size(path_image)
+
+        path_label = path_image.with_suffix(".txt")
+        if not path_label.exists():
+            raise FileNotFoundError(f"File with labels not found at '{path_label.resolve()}'")
+
+        boxes: List[primitives.Bbox] = []
+        bbox_strings = path_label.read_text().splitlines()
+        for bbox_string in bbox_strings:
+            parts = bbox_string.strip().split()
+            if len(parts) != 5:
+                raise ValueError(f"Invalid bbox format in file {path_label.resolve()}: {bbox_string}")
+
+            class_idx = int(parts[0])
+            x_center, y_center, bbox_width, bbox_height = (float(value) for value in parts[1:5])
+
+            top_left_x = x_center - bbox_width / 2
+            top_left_y = y_center - bbox_height / 2
+
+            bbox = primitives.Bbox(
+                top_left_x=top_left_x,
+                top_left_y=top_left_y,
+                width=bbox_width,
+                height=bbox_height,
+                class_idx=class_idx,
+                class_name=class_names[class_idx] if 0 <= class_idx < len(class_names) else None,
+            )
+            boxes.append(bbox)
+
+        sample = Sample(
+            file_path=path_image.absolute().as_posix(),
+            height=height,
+            width=width,
+            split=split_name,
+            bboxes=boxes,
+        )
+        samples.append(sample)
+
+    tasks = [TaskInfo(primitive=primitives.Bbox, class_names=class_names)]
+    info = DatasetInfo(dataset_name=dataset_name, tasks=tasks)
+    hafnia_dataset = HafniaDataset.from_samples_list(samples, info=info)
+    return hafnia_dataset
+
+
+def to_yolo_format(
+    dataset: "HafniaDataset",
+    path_export_yolo_dataset: Path,
+    task_name: Optional[str] = None,
+):
+    """Exports a HafniaDataset as YOLO (Darknet) format."""
+    from hafnia.dataset.hafnia_dataset import Sample
+
+    bbox_task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitives.Bbox)
+
+    class_names = bbox_task.class_names or []
+    if len(class_names) == 0:
+        raise ValueError(
+            f"Hafnia dataset task '{bbox_task.name}' has no class names defined. This is required for YOLO export."
+        )
+    path_export_yolo_dataset.mkdir(parents=True, exist_ok=True)
+    path_class_names = path_export_yolo_dataset / FILENAME_YOLO_CLASS_NAMES
+    path_class_names.write_text("\n".join(class_names))
+
+    path_data_folder = path_export_yolo_dataset / "data"
+    path_data_folder.mkdir(parents=True, exist_ok=True)
+    image_paths: List[str] = []
+    for sample_dict in dataset:
+        sample = Sample(**sample_dict)
+        if sample.file_path is None:
+            raise ValueError("Sample has no file_path defined.")
+        path_image_src = Path(sample.file_path)
+        path_image_dst = path_data_folder / path_image_src.name
+        shutil.copy2(path_image_src, path_image_dst)
+        image_paths.append(path_image_dst.relative_to(path_export_yolo_dataset).as_posix())
+        path_label = path_image_dst.with_suffix(".txt")
+        bboxes = sample.bboxes or []
+        bbox_strings = [bbox_to_yolo_format(bbox) for bbox in bboxes]
+        path_label.write_text("\n".join(bbox_strings))
+
+    path_images_txt = path_export_yolo_dataset / FILENAME_YOLO_IMAGES_TXT
+    path_images_txt.write_text("\n".join(image_paths))
+
+
+def bbox_to_yolo_format(bbox: primitives.Bbox) -> str:
+    """
+    From hafnia bbox to yolo bbox string conversion
+    Both yolo and hafnia use normalized coordinates [0, 1]
+    Hafnia: top_left_x, top_left_y, width, height
+    Yolo (darknet): "<object-class> <x_center> <y_center> <width> <height>"
+    Example (3 bounding boxes):
+    1 0.716797 0.395833 0.216406 0.147222
+    0 0.687109 0.379167 0.255469 0.158333
+    1 0.420312 0.395833 0.140625 0.166667
+    """
+    x_center = bbox.top_left_x + bbox.width / 2
+    y_center = bbox.top_left_y + bbox.height / 2
+    return f"{bbox.class_idx} {x_center} {y_center} {bbox.width} {bbox.height}"
```
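The only arithmetic in the YOLO round trip is the corner↔center conversion: Hafnia stores `(top_left_x, top_left_y, width, height)`, YOLO stores `(x_center, y_center, width, height)`, both normalized to [0, 1]. A quick sanity check that the two directions shown above invert each other (the `Bbox` arguments mirror the import code; the values and class name are illustrative):

```python
from hafnia.dataset import primitives
from hafnia.dataset.format_conversions.format_yolo import bbox_to_yolo_format

bbox = primitives.Bbox(
    top_left_x=0.25, top_left_y=0.25, width=0.5, height=0.5,
    class_idx=1, class_name="person",  # hypothetical class
)
line = bbox_to_yolo_format(bbox)  # -> "1 0.5 0.5 0.5 0.5"

# from_yolo_format() inverts this when parsing a label line:
class_idx, x_c, y_c, w, h = line.split()
assert float(x_c) - float(w) / 2 == bbox.top_left_x
```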
hafnia/dataset/format_conversions/torchvision_datasets.py
CHANGED
```diff
@@ -14,8 +14,8 @@ from torchvision.datasets.utils import download_and_extract_archive, extract_archive
 from hafnia import utils
 from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name
 from hafnia.dataset.dataset_names import SplitName
-from hafnia.dataset.format_conversions.image_classification_from_directory import (
-    import_image_classification_directory_tree,
+from hafnia.dataset.format_conversions.format_image_classification_folder import (
+    from_image_classification_folder,
 )
 from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
 from hafnia.dataset.primitives import Classification
@@ -72,7 +72,7 @@ def caltech_101_as_hafnia_dataset(
     path_image_classification_folder = _download_and_extract_caltech_dataset(
         dataset_name, force_redownload=force_redownload
     )
-    hafnia_dataset = import_image_classification_directory_tree(
+    hafnia_dataset = from_image_classification_folder(
         path_image_classification_folder,
         split=SplitName.TRAIN,
         n_samples=n_samples,
@@ -102,7 +102,7 @@ def caltech_256_as_hafnia_dataset(
     path_image_classification_folder = _download_and_extract_caltech_dataset(
         dataset_name, force_redownload=force_redownload
     )
-    hafnia_dataset = import_image_classification_directory_tree(
+    hafnia_dataset = from_image_classification_folder(
        path_image_classification_folder,
        split=SplitName.TRAIN,
        n_samples=n_samples,
@@ -122,6 +122,12 @@
     }""")
     hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/nyy15-4j048"
 
+    task = hafnia_dataset.info.get_task_by_primitive(Classification)
+
+    # Class Mapping: To remove numeric prefixes from class names
+    # E.g. "001.ak47 --> ak47", "002.american-flag --> american-flag", ...
+    class_mapping = {name: name.split(".")[-1] for name in task.class_names or []}
+    hafnia_dataset = hafnia_dataset.class_mapper(class_mapping=class_mapping, task_name=task.name)
     return hafnia_dataset
 
 
```
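The Caltech-256 class mapping added above is a plain dict comprehension; on that dataset's raw folder names it strips the numeric prefix (example names taken from the comment in the diff):

```python
class_names = ["001.ak47", "002.american-flag"]
class_mapping = {name: name.split(".")[-1] for name in class_names}
# -> {"001.ak47": "ak47", "002.american-flag": "american-flag"}
```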