hafnia 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/dataset/{dataset_upload_helper.py → dataset_details_uploader.py} +114 -191
- hafnia/dataset/dataset_names.py +26 -0
- hafnia/dataset/format_conversions/format_coco.py +490 -0
- hafnia/dataset/format_conversions/format_helpers.py +33 -0
- hafnia/dataset/format_conversions/format_image_classification_folder.py +95 -14
- hafnia/dataset/format_conversions/format_yolo.py +115 -25
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -8
- hafnia/dataset/hafnia_dataset.py +20 -466
- hafnia/dataset/hafnia_dataset_types.py +477 -0
- hafnia/dataset/license_types.py +4 -4
- hafnia/dataset/operations/dataset_stats.py +3 -3
- hafnia/dataset/operations/dataset_transformations.py +14 -17
- hafnia/dataset/operations/table_transformations.py +20 -13
- hafnia/dataset/primitives/bbox.py +6 -2
- hafnia/dataset/primitives/bitmask.py +21 -46
- hafnia/dataset/primitives/classification.py +1 -1
- hafnia/dataset/primitives/polygon.py +43 -2
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +1 -1
- hafnia/experiment/hafnia_logger.py +13 -4
- hafnia/platform/datasets.py +2 -3
- hafnia/torch_helpers.py +48 -4
- hafnia/utils.py +34 -0
- hafnia/visualizations/image_visualizations.py +3 -1
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/METADATA +2 -2
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/RECORD +29 -26
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/WHEEL +0 -0
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
CHANGED
@@ -1,25 +1,15 @@
 from __future__ import annotations
 
-import collections
 import copy
-import json
 import shutil
 from dataclasses import dataclass
-from datetime import datetime
 from pathlib import Path
 from random import Random
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
-import cv2
-import more_itertools
-import numpy as np
 import polars as pl
 from packaging.version import Version
-from PIL import Image
-from pydantic import BaseModel, Field, field_serializer, field_validator
-from rich.progress import track
 
-import hafnia
 from hafnia.dataset import dataset_helpers
 from hafnia.dataset.dataset_names import (
     DATASET_FILENAMES_REQUIRED,
@@ -35,470 +25,19 @@ from hafnia.dataset.dataset_names import (
     StorageFormat,
 )
 from hafnia.dataset.format_conversions import (
+    format_coco,
     format_image_classification_folder,
     format_yolo,
 )
+from hafnia.dataset.hafnia_dataset_types import DatasetInfo, Sample
 from hafnia.dataset.operations import (
     dataset_stats,
     dataset_transformations,
     table_transformations,
 )
-from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
-from hafnia.dataset.primitives.bbox import Bbox
-from hafnia.dataset.primitives.bitmask import Bitmask
-from hafnia.dataset.primitives.classification import Classification
-from hafnia.dataset.primitives.polygon import Polygon
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.log import user_logger
-
-
-class TaskInfo(BaseModel):
-    primitive: Type[Primitive] = Field(
-        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
-    )
-    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
-    name: Optional[str] = Field(
-        default=None,
-        description=(
-            "Optional name for the task. 'None' will use default name of the provided primitive. "
-            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
-        ),
-    )
-
-    def model_post_init(self, __context: Any) -> None:
-        if self.name is None:
-            self.name = self.primitive.default_task_name()
-
-    def get_class_index(self, class_name: str) -> int:
-        """Get class index for a given class name"""
-        if self.class_names is None:
-            raise ValueError(f"Task '{self.name}' has no class names defined.")
-        if class_name not in self.class_names:
-            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
-        return self.class_names.index(class_name)
-
-    # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
-    # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
-    # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
-    # correctly validate and serialize sub-classes (Bbox, Classification, ...).
-    @field_validator("primitive", mode="plain")
-    @classmethod
-    def ensure_primitive(cls, primitive: Any) -> Any:
-        if isinstance(primitive, str):
-            return get_primitive_type_from_string(primitive)
-
-        if issubclass(primitive, Primitive):
-            return primitive
-
-        raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")
-
-    @field_serializer("primitive")
-    @classmethod
-    def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
-        if not issubclass(primitive, Primitive):
-            raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
-        return primitive.__name__
-
-    @field_validator("class_names", mode="after")
-    @classmethod
-    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
-        """Validate that class names are unique"""
-        if class_names is None:
-            return None
-        duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
-        if duplicate_class_names:
-            raise ValueError(
-                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
-            )
-        return class_names
-
-    def full_name(self) -> str:
-        """Get qualified name for the task: <primitive_name>:<task_name>"""
-        return f"{self.primitive.__name__}:{self.name}"
-
-    # To get unique hash value for TaskInfo objects
-    def __hash__(self) -> int:
-        class_names = self.class_names or []
-        return hash((self.name, self.primitive.__name__, tuple(class_names)))
-
-    def __eq__(self, other: Any) -> bool:
-        if not isinstance(other, TaskInfo):
-            return False
-        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
-
-
-class DatasetInfo(BaseModel):
-    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
-    version: Optional[str] = Field(default=None, description="Version of the dataset")
-    tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
-    reference_bibtex: Optional[str] = Field(
-        default=None,
-        description="Optional, BibTeX reference to dataset publication",
-    )
-    reference_paper_url: Optional[str] = Field(
-        default=None,
-        description="Optional, URL to dataset publication",
-    )
-    reference_dataset_page: Optional[str] = Field(
-        default=None,
-        description="Optional, URL to the dataset page",
-    )
-    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
-    format_version: str = Field(
-        default=hafnia.__dataset_format_version__,
-        description="Version of the Hafnia dataset format. You should not set this manually.",
-    )
-    updated_at: datetime = Field(
-        default_factory=datetime.now,
-        description="Timestamp of the last update to the dataset info. You should not set this manually.",
-    )
-
-    @field_validator("tasks", mode="after")
-    @classmethod
-    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
-        if tasks is None:
-            return []
-        task_name_counts = collections.Counter(task.name for task in tasks)
-        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
-        if duplicate_task_names:
-            raise ValueError(
-                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
-            )
-        return tasks
-
-    @field_validator("format_version")
-    @classmethod
-    def _validate_format_version(cls, format_version: str) -> str:
-        try:
-            Version(format_version)
-        except Exception as e:
-            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
-
-        if Version(format_version) > Version(hafnia.__dataset_format_version__):
-            user_logger.warning(
-                f"The loaded dataset format version '{format_version}' is newer than the format version "
-                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
-                f"updating Hafnia package."
-            )
-        return format_version
-
-    @field_validator("version")
-    @classmethod
-    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
-        if dataset_version is None:
-            return None
-
-        try:
-            Version(dataset_version)
-        except Exception as e:
-            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
-
-        return dataset_version
-
-    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
-        return self._validate_check_for_duplicate_tasks(self.tasks)
-
-    def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
-        json_str = self.model_dump_json(indent=indent)
-        path.write_text(json_str)
-
-    @staticmethod
-    def from_json_file(path: Path) -> DatasetInfo:
-        json_str = path.read_text()
-
-        # TODO: Deprecated support for old dataset info without format_version
-        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
-        # when all datasets include a 'format_version' field
-        json_dict = json.loads(json_str)
-        if "format_version" not in json_dict:
-            json_dict["format_version"] = "0.0.0"
-
-        if "updated_at" not in json_dict:
-            json_dict["updated_at"] = datetime.min.isoformat()
-        dataset_info = DatasetInfo.model_validate(json_dict)
-
-        return dataset_info
-
-    @staticmethod
-    def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
-        """
-        Merges two DatasetInfo objects into one and validates if they are compatible.
-        """
-        for task_ds0 in info0.tasks:
-            for task_ds1 in info1.tasks:
-                same_name = task_ds0.name == task_ds1.name
-                same_primitive = task_ds0.primitive == task_ds1.primitive
-                same_name_different_primitive = same_name and not same_primitive
-                if same_name_different_primitive:
-                    raise ValueError(
-                        f"Cannot merge datasets with different primitives for the same task name: "
-                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
-                        f"'{task_ds1.primitive}' in dataset1."
-                    )
-
-                is_same_name_and_primitive = same_name and same_primitive
-                if is_same_name_and_primitive:
-                    task_ds0_class_names = task_ds0.class_names or []
-                    task_ds1_class_names = task_ds1.class_names or []
-                    if task_ds0_class_names != task_ds1_class_names:
-                        raise ValueError(
-                            f"Cannot merge datasets with different class names for the same task name and primitive: "
-                            f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
-                            f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
-                        )
-
-        if info1.format_version != info0.format_version:
-            user_logger.warning(
-                "Dataset format version of the two datasets do not match. "
-                f"'{info1.format_version}' vs '{info0.format_version}'."
-            )
-        dataset_format_version = info0.format_version
-        if hafnia.__dataset_format_version__ != dataset_format_version:
-            user_logger.warning(
-                f"Dataset format version '{dataset_format_version}' does not match the current "
-                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
-            )
-        unique_tasks = set(info0.tasks + info1.tasks)
-        meta = (info0.meta or {}).copy()
-        meta.update(info1.meta or {})
-        return DatasetInfo(
-            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
-            version=None,
-            tasks=list(unique_tasks),
-            meta=meta,
-            format_version=dataset_format_version,
-        )
-
-    def get_task_by_name(self, task_name: str) -> TaskInfo:
-        """
-        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
-        """
-        tasks_with_name = [task for task in self.tasks if task.name == task_name]
-        if not tasks_with_name:
-            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
-        if len(tasks_with_name) > 1:
-            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
-        return tasks_with_name[0]
-
-    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
-        """
-        Get all tasks by their primitive type.
-        """
-        if isinstance(primitive, str):
-            primitive = get_primitive_type_from_string(primitive)
-
-        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
-        return tasks_with_primitive
-
-    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
-        """
-        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
-        have the same primitive type.
-        """
-
-        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
-        if len(tasks_with_primitive) == 0:
-            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
-        if len(tasks_with_primitive) > 1:
-            raise ValueError(
-                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
-            )
-        return tasks_with_primitive[0]
-
-    def get_task_by_task_name_and_primitive(
-        self,
-        task_name: Optional[str],
-        primitive: Optional[Union[Type[Primitive], str]],
-    ) -> TaskInfo:
-        """
-        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
-        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
-        """
-        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
-            tasks=self.tasks,
-            primitive=primitive,
-            task_name=task_name,
-        )
-        return task
-
-    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
-        dataset_info = self.model_copy(deep=True)
-        has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
-        if not has_task:
-            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
-
-        new_tasks = []
-        for task in dataset_info.tasks:
-            if task.name == old_task.name and task.primitive == old_task.primitive:
-                if new_task is None:
-                    continue  # Remove the task
-                new_tasks.append(new_task)
-            else:
-                new_tasks.append(task)
-
-        dataset_info.tasks = new_tasks
-        return dataset_info
-
-
-class Sample(BaseModel):
-    file_path: Optional[str] = Field(description="Path to the image/video file.")
-    height: int = Field(description="Height of the image")
-    width: int = Field(description="Width of the image")
-    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
-    tags: List[str] = Field(
-        default_factory=list,
-        description="Tags for a given sample. Used for creating subsets of the dataset.",
-    )
-    storage_format: str = Field(
-        default=StorageFormat.IMAGE,
-        description="Storage format. Sample data is stored as image or inside a video or zip file.",
-    )
-    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
-    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
-    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
-    sample_index: Optional[int] = Field(
-        default=None,
-        description="Don't manually set this, it is used for indexing samples in the dataset.",
-    )
-    classifications: Optional[List[Classification]] = Field(
-        default=None, description="Optional list of classifications"
-    )
-    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
-    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
-    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
-
-    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
-    dataset_name: Optional[str] = Field(
-        default=None,
-        description=(
-            "Don't manually set this, it will be automatically defined during initialization. "
-            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
-        ),
-    )
-    meta: Optional[Dict] = Field(
-        default=None,
-        description="Additional metadata, e.g., camera settings, GPS data, etc.",
-    )
-
-    def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
-        """
-        Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.
-        """
-        primitive_types = primitive_types or PRIMITIVE_TYPES
-        annotations_primitives = [
-            getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
-        ]
-        annotations = more_itertools.flatten(
-            [primitives for primitives in annotations_primitives if primitives is not None]
-        )
-
-        return list(annotations)
-
-    def read_image_pillow(self) -> Image.Image:
-        """
-        Reads the image from the file path and returns it as a PIL Image.
-        Raises FileNotFoundError if the image file does not exist.
-        """
-        if self.file_path is None:
-            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
-        path_image = Path(self.file_path)
-        if not path_image.exists():
-            raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
-
-        image = Image.open(str(path_image))
-        return image
-
-    def read_image(self) -> np.ndarray:
-        if self.storage_format == StorageFormat.VIDEO:
-            video = cv2.VideoCapture(str(self.file_path))
-            if self.collection_index is None:
-                raise ValueError("collection_index must be set for video storage format to read the correct frame.")
-            video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
-            success, image = video.read()
-            video.release()
-            if not success:
-                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
-            return image
-
-        elif self.storage_format == StorageFormat.IMAGE:
-            image_pil = self.read_image_pillow()
-            image = np.array(image_pil)
-        else:
-            raise ValueError(f"Unsupported storage format: {self.storage_format}")
-        return image
-
-    def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
-        from hafnia.visualizations import image_visualizations
-
-        image = image or self.read_image()
-        annotations = self.get_annotations()
-        annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
-        return annotations_visualized
-
-
-class License(BaseModel):
-    """License information"""
-
-    name: Optional[str] = Field(
-        default=None,
-        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
-        max_length=100,
-    )
-    name_short: Optional[str] = Field(
-        default=None,
-        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
-        max_length=100,
-    )
-    url: Optional[str] = Field(
-        default=None,
-        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
-    )
-    description: Optional[str] = Field(
-        default=None,
-        description=(
-            "License description e.g. 'You must give appropriate credit, provide a "
-            "link to the license, and indicate if changes were made.'"
-        ),
-    )
-
-    valid_date: Optional[datetime] = Field(
-        default=None,
-        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
-    )
-
-    permissions: Optional[List[str]] = Field(
-        default=None,
-        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
-    )
-    liability: Optional[str] = Field(
-        default=None,
-        description="License liability. Optional and not always applicable.",
-    )
-    location: Optional[str] = Field(
-        default=None,
-        description=(
-            "License Location. E.g. Iowa state. This is essential to understand the industry and "
-            "privacy location specific rules that applies to the data. Optional and not always applicable."
-        ),
-    )
-    notes: Optional[str] = Field(
-        default=None,
-        description="Additional license notes. Optional and not always applicable.",
-    )
-
-
-class Attribution(BaseModel):
-    """Attribution information for the image: Giving source and credit to the original creator"""
-
-    title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
-    creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
-    creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
-    date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
-    copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
-    licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
-    disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
-    changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
-    source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+from hafnia.utils import progress_bar
 
 
 @dataclass
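
The bulk of the hunk above is a mechanical move: TaskInfo, DatasetInfo, Sample, License, and Attribution leave hafnia_dataset.py for the new hafnia/dataset/hafnia_dataset_types.py (+477 lines in the file list), and hafnia_dataset.py re-imports only DatasetInfo and Sample from it. A minimal sketch of the new import path follows; that TaskInfo and the other moved classes are importable from hafnia_dataset_types is an assumption inferred from the file list, not shown in this diff:

    # Sketch: the module path comes from the diff; the availability of TaskInfo
    # there is an assumption (only DatasetInfo and Sample re-imports are shown).
    from hafnia.dataset.hafnia_dataset_types import DatasetInfo, Sample, TaskInfo

    info = DatasetInfo(
        dataset_name="toy-dataset",
        tasks=[TaskInfo(primitive="Bbox", class_names=["car", "person"])],
    )
    # Per the moved TaskInfo code, a string primitive is resolved to the Bbox
    # class and the task name defaults to the primitive's default task name.
    print(info.tasks[0].full_name())  # e.g. "Bbox:bboxes"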
@@ -527,8 +66,10 @@ class HafniaDataset:
     convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
 
     # Import / export functions
-    from_yolo_format = format_yolo.from_yolo_format
     to_yolo_format = format_yolo.to_yolo_format
+    from_yolo_format = format_yolo.from_yolo_format
+    to_coco_format = format_coco.to_coco_format
+    from_coco_format = format_coco.from_coco_format
     to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
     from_image_classification_folder = format_image_classification_folder.from_image_classification_folder
 
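
COCO import/export is the headline of this hunk: the new format_coco module (+490 lines in the file list) is wired onto HafniaDataset next to the existing YOLO and image-classification-folder pairs. Below is a usage sketch under the assumption that the COCO pair is path-based like its YOLO siblings; the real signatures are not visible in this diff:

    # Assumed usage: the method names come from the diff above, while the
    # path-based signatures mirror the YOLO conversions and are an assumption.
    from pathlib import Path

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_coco_format(Path("data/coco_export"))  # read a COCO-format dataset
    dataset.to_coco_format(Path("out/coco_roundtrip"))                  # write it back out as COCO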
@@ -978,6 +519,10 @@ class HafniaDataset:
         dataset.check_dataset_tasks()
         return dataset
 
+    def has_primitive(dataset: HafniaDataset, PrimitiveType: Type[Primitive]) -> bool:
+        table = dataset.samples if isinstance(dataset, HafniaDataset) else dataset
+        return table_transformations.has_primitive(table, PrimitiveType)
+
     @staticmethod
     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
         """
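
Note that has_primitive is declared without @staticmethod and with dataset rather than self as the first parameter, so the isinstance check lets it serve double duty: called on an instance it receives the HafniaDataset, while called through the class it can take a bare polars table. A short sketch of both call styles (the dataset variable is hypothetical):

    from hafnia.dataset.primitives.bbox import Bbox

    dataset.has_primitive(Bbox)                         # bound call: dataset becomes the first argument
    HafniaDataset.has_primitive(dataset.samples, Bbox)  # unbound call: pass a polars DataFrame directly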
@@ -1026,7 +571,7 @@ class HafniaDataset:
         hafnia_dataset = self.copy()  # To avoid inplace modifications
         new_paths = []
         org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
-        for org_path in
+        for org_path in progress_bar(org_paths, description="- Copy images"):
             new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
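
progress_bar is new in hafnia/utils.py (+34 lines in the file list) and replaces the rich.progress.track import that was dropped from the top of this file. Only the call site is visible here; a plausible shape for the helper, sketched purely as an assumption, is a thin wrapper over track:

    # Hypothetical sketch of hafnia.utils.progress_bar; only the call form
    # progress_bar(items, description=...) is confirmed by this diff.
    from typing import Iterable, TypeVar

    from rich.progress import track

    T = TypeVar("T")

    def progress_bar(items: Iterable[T], description: str = "") -> Iterable[T]:
        # Funneling all progress reporting through one wrapper makes it easy
        # to later disable or restyle output for every loop at once.
        return track(items, description=description)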
@@ -1145,4 +690,13 @@ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tu
     if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
         samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
 
+    if format_version_of_dataset <= Version("0.2.0"):
+        if SampleField.BITMASKS in samples.columns and samples[SampleField.BITMASKS].dtype == pl.List(pl.Struct):
+            struct_schema = samples.schema[SampleField.BITMASKS].inner
+            struct_names = [f.name for f in struct_schema.fields]
+            if "rleString" in struct_names:
+                struct_names[struct_names.index("rleString")] = "rle_string"
+                samples = samples.with_columns(
+                    pl.col(SampleField.BITMASKS).list.eval(pl.element().struct.rename_fields(struct_names))
+                )
     return samples, dataset_info
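
The added block migrates pre-0.2.0 datasets whose bitmask structs still carry the camelCase field rleString. Renaming a struct field nested inside a list column is easy to get wrong in polars, so here is a self-contained toy reproduction of the same pattern (the data is made up; the column and field names mirror the diff):

    import polars as pl

    samples = pl.DataFrame(
        {"bitmasks": [[{"rleString": "3x2b", "height": 2}], [{"rleString": "9f", "height": 3}]]}
    )
    struct_schema = samples.schema["bitmasks"].inner        # Struct dtype of the list elements
    struct_names = [f.name for f in struct_schema.fields]   # ["rleString", "height"]
    struct_names[struct_names.index("rleString")] = "rle_string"

    # rename_fields assigns names positionally, so "height" keeps its name.
    samples = samples.with_columns(
        pl.col("bitmasks").list.eval(pl.element().struct.rename_fields(struct_names))
    )
    print(samples.schema)  # bitmasks: List(Struct({'rle_string': String, 'height': Int64}))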