hafnia 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/dataset/{dataset_upload_helper.py → dataset_details_uploader.py} +114 -191
- hafnia/dataset/dataset_names.py +26 -0
- hafnia/dataset/format_conversions/format_coco.py +490 -0
- hafnia/dataset/format_conversions/format_helpers.py +33 -0
- hafnia/dataset/format_conversions/format_image_classification_folder.py +95 -14
- hafnia/dataset/format_conversions/format_yolo.py +115 -25
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -8
- hafnia/dataset/hafnia_dataset.py +20 -466
- hafnia/dataset/hafnia_dataset_types.py +477 -0
- hafnia/dataset/license_types.py +4 -4
- hafnia/dataset/operations/dataset_stats.py +3 -3
- hafnia/dataset/operations/dataset_transformations.py +14 -17
- hafnia/dataset/operations/table_transformations.py +20 -13
- hafnia/dataset/primitives/bbox.py +6 -2
- hafnia/dataset/primitives/bitmask.py +21 -46
- hafnia/dataset/primitives/classification.py +1 -1
- hafnia/dataset/primitives/polygon.py +43 -2
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +1 -1
- hafnia/experiment/hafnia_logger.py +13 -4
- hafnia/platform/datasets.py +2 -3
- hafnia/torch_helpers.py +48 -4
- hafnia/utils.py +34 -0
- hafnia/visualizations/image_visualizations.py +3 -1
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/METADATA +2 -2
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/RECORD +29 -26
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/WHEEL +0 -0
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import json
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
import more_itertools
|
|
9
|
+
import numpy as np
|
|
10
|
+
from packaging.version import Version
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
13
|
+
|
|
14
|
+
import hafnia
|
|
15
|
+
from hafnia.dataset.dataset_names import SampleField, StorageFormat
|
|
16
|
+
from hafnia.dataset.primitives import (
|
|
17
|
+
PRIMITIVE_TYPES,
|
|
18
|
+
Bbox,
|
|
19
|
+
Bitmask,
|
|
20
|
+
Classification,
|
|
21
|
+
Polygon,
|
|
22
|
+
get_primitive_type_from_string,
|
|
23
|
+
)
|
|
24
|
+
from hafnia.dataset.primitives.primitive import Primitive
|
|
25
|
+
from hafnia.log import user_logger
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TaskInfo(BaseModel):
    """Describe a single annotation task of a dataset.

    A task couples a primitive type (a ``Primitive`` subclass) with an optional
    list of class names and a task name. When no name is given, the primitive's
    ``default_task_name()`` is used.
    """

    primitive: Type[Primitive] = Field(
        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
    )
    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
    name: Optional[str] = Field(
        default=None,
        description=(
            "Optional name for the task. 'None' will use default name of the provided primitive. "
            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
        ),
    )

    def model_post_init(self, __context: Any) -> None:
        """Fill in the task name from the primitive when none was provided."""
        if self.name is None:
            self.name = self.primitive.default_task_name()

    def get_class_index(self, class_name: str) -> int:
        """Return the index of ``class_name`` in ``class_names``.

        Raises:
            ValueError: If the task has no class names or ``class_name`` is not one of them.
        """
        if self.class_names is None:
            raise ValueError(f"Task '{self.name}' has no class names defined.")
        if class_name not in self.class_names:
            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
        return self.class_names.index(class_name)

    # The 'primitive' field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box, as
    # the 'Primitive' class is an abstract base class for the actual primitives such as Bbox, Bitmask, Classification.
    # The functions below ('ensure_primitive' and 'serialize_primitive') ensure that the 'primitive' field can
    # correctly validate and serialize sub-classes (Bbox, Classification, ...).
    @field_validator("primitive", mode="plain")
    @classmethod
    def ensure_primitive(cls, primitive: Any) -> Any:
        """Accept a primitive given either as its string name or as a ``Primitive`` subclass.

        Raises:
            ValueError: If ``primitive`` is neither a string nor a ``Primitive`` subclass.
        """
        if isinstance(primitive, str):
            return get_primitive_type_from_string(primitive)

        # 'issubclass' raises TypeError when its first argument is not a class (e.g. an int or
        # an instance). Guard with 'isinstance(..., type)' so such input reaches the ValueError
        # below instead of escaping as a bare TypeError.
        if isinstance(primitive, type) and issubclass(primitive, Primitive):
            return primitive

        raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")

    @field_serializer("primitive")
    @classmethod
    def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
        """Serialize the primitive class as its class name.

        Raises:
            ValueError: If ``primitive`` is not a ``Primitive`` subclass.
        """
        # Same guard as in 'ensure_primitive': non-class values must not trigger a TypeError.
        if not (isinstance(primitive, type) and issubclass(primitive, Primitive)):
            raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
        return primitive.__name__

    @field_validator("class_names", mode="after")
    @classmethod
    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
        """Validate that class names are unique.

        Raises:
            ValueError: If any class name appears more than once.
        """
        if class_names is None:
            return None
        # Single counting pass instead of calling 'list.count' once per element.
        name_counts = collections.Counter(class_names)
        duplicate_class_names = {name for name, count in name_counts.items() if count > 1}
        if duplicate_class_names:
            raise ValueError(
                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
            )
        return class_names

    def full_name(self) -> str:
        """Get qualified name for the task: <primitive_name>:<task_name>"""
        return f"{self.primitive.__name__}:{self.name}"

    # To get unique hash value for TaskInfo objects
    def __hash__(self) -> int:
        # 'None' and an empty list hash identically here; '__eq__' still tells them apart.
        class_names = self.class_names or []
        return hash((self.name, self.primitive.__name__, tuple(class_names)))

    def __eq__(self, other: Any) -> bool:
        """Two tasks are equal when name, primitive and class names all match."""
        if not isinstance(other, TaskInfo):
            return False
        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class DatasetInfo(BaseModel):
    """Dataset-level metadata: name, version, tasks, references and format version."""

    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
    version: Optional[str] = Field(default=None, description="Version of the dataset")
    # An empty list (not None) is the default: the field is typed as a list, the "after"
    # validator below does not run on defaults, and the methods of this class iterate
    # over 'tasks' directly. A None default also could not be re-loaded from JSON.
    tasks: List[TaskInfo] = Field(default_factory=list, description="List of tasks in the dataset")
    reference_bibtex: Optional[str] = Field(
        default=None,
        description="Optional, BibTeX reference to dataset publication",
    )
    reference_paper_url: Optional[str] = Field(
        default=None,
        description="Optional, URL to dataset publication",
    )
    reference_dataset_page: Optional[str] = Field(
        default=None,
        description="Optional, URL to the dataset page",
    )
    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
    format_version: str = Field(
        default=hafnia.__dataset_format_version__,
        description="Version of the Hafnia dataset format. You should not set this manually.",
    )
    updated_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of the last update to the dataset info. You should not set this manually.",
    )

    @field_validator("tasks", mode="after")
    @classmethod
    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
        """Reject task lists in which two tasks share the same name.

        Raises:
            ValueError: If a task name occurs more than once.
        """
        if tasks is None:
            return []
        task_name_counts = collections.Counter(task.name for task in tasks)
        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
        if duplicate_task_names:
            raise ValueError(
                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
            )
        return tasks

    @field_validator("format_version")
    @classmethod
    def _validate_format_version(cls, format_version: str) -> str:
        """Check that ``format_version`` parses as a version and warn when it is newer than supported.

        Raises:
            ValueError: If ``format_version`` is not a valid version string.
        """
        try:
            parsed_version = Version(format_version)
        except Exception as e:
            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e

        if parsed_version > Version(hafnia.__dataset_format_version__):
            user_logger.warning(
                f"The loaded dataset format version '{format_version}' is newer than the format version "
                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
                f"updating Hafnia package."
            )
        return format_version

    @field_validator("version")
    @classmethod
    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
        """Check that the dataset version, when given, parses as a version string.

        Raises:
            ValueError: If ``dataset_version`` is not a valid version string.
        """
        if dataset_version is None:
            return None

        try:
            Version(dataset_version)
        except Exception as e:
            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e

        return dataset_version

    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
        """Re-run the duplicate-task-name check on the current ``tasks`` and return them."""
        return self._validate_check_for_duplicate_tasks(self.tasks)

    def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
        """Serialize this object to JSON and write it to ``path``."""
        json_str = self.model_dump_json(indent=indent)
        path.write_text(json_str)

    @staticmethod
    def from_json_file(path: Path) -> "DatasetInfo":
        """Load a ``DatasetInfo`` from a JSON file, filling in fields missing in older files."""
        json_str = path.read_text()

        # TODO: Deprecated support for old dataset info without format_version
        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
        # when all datasets include a 'format_version' field
        json_dict = json.loads(json_str)
        if "format_version" not in json_dict:
            json_dict["format_version"] = "0.0.0"

        if "updated_at" not in json_dict:
            json_dict["updated_at"] = datetime.min.isoformat()
        dataset_info = DatasetInfo.model_validate(json_dict)

        return dataset_info

    @staticmethod
    def merge(info0: "DatasetInfo", info1: "DatasetInfo") -> "DatasetInfo":
        """
        Merges two DatasetInfo objects into one and validates if they are compatible.

        Tasks that share a name must also share primitive and class names. The merged
        task list keeps first-occurrence order (tasks of ``info0`` first), ``meta`` of
        ``info1`` overrides ``meta`` of ``info0`` on key clashes, and the format version
        of ``info0`` is used.

        Raises:
            ValueError: If a task name maps to different primitives or different class names.
        """
        for task_ds0 in info0.tasks:
            for task_ds1 in info1.tasks:
                if task_ds0.name != task_ds1.name:
                    continue

                if task_ds0.primitive != task_ds1.primitive:
                    raise ValueError(
                        f"Cannot merge datasets with different primitives for the same task name: "
                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
                        f"'{task_ds1.primitive}' in dataset1."
                    )

                # Same name and same primitive: class names must agree (None is treated as empty).
                task_ds0_class_names = task_ds0.class_names or []
                task_ds1_class_names = task_ds1.class_names or []
                if task_ds0_class_names != task_ds1_class_names:
                    raise ValueError(
                        f"Cannot merge datasets with different class names for the same task name and primitive: "
                        f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
                        f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
                    )

        if info1.format_version != info0.format_version:
            user_logger.warning(
                "Dataset format version of the two datasets do not match. "
                f"'{info1.format_version}' vs '{info0.format_version}'."
            )
        dataset_format_version = info0.format_version
        if hafnia.__dataset_format_version__ != dataset_format_version:
            user_logger.warning(
                f"Dataset format version '{dataset_format_version}' does not match the current "
                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
            )

        # 'dict.fromkeys' removes duplicates like a set would (TaskInfo is hashable) but keeps
        # insertion order, so the merged task order is the same on every run.
        unique_tasks = list(dict.fromkeys(info0.tasks + info1.tasks))
        meta = (info0.meta or {}).copy()
        meta.update(info1.meta or {})
        return DatasetInfo(
            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
            version=None,
            tasks=unique_tasks,
            meta=meta,
            format_version=dataset_format_version,
        )

    def get_task_by_name(self, task_name: str) -> TaskInfo:
        """
        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
        """
        tasks_with_name = [task for task in self.tasks if task.name == task_name]
        if not tasks_with_name:
            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
        if len(tasks_with_name) > 1:
            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
        return tasks_with_name[0]

    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
        """
        Get all tasks by their primitive type.
        """
        if isinstance(primitive, str):
            primitive = get_primitive_type_from_string(primitive)

        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
        return tasks_with_primitive

    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
        """
        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
        have the same primitive type.
        """

        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
        if len(tasks_with_primitive) == 0:
            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
        if len(tasks_with_primitive) > 1:
            raise ValueError(
                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
            )
        return tasks_with_primitive[0]

    def get_task_by_task_name_and_primitive(
        self,
        task_name: Optional[str],
        primitive: Optional[Union[Type[Primitive], str]],
    ) -> TaskInfo:
        """
        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
        """
        # Imported here rather than at module level; dataset_transformations imports this module.
        from hafnia.dataset.operations import dataset_transformations

        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
            tasks=self.tasks,
            primitive=primitive,
            task_name=task_name,
        )
        return task

    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> "DatasetInfo":
        """Return a deep copy in which ``old_task`` is replaced by ``new_task``.

        Tasks are matched on name and primitive. Passing ``new_task=None`` removes the task.

        Raises:
            ValueError: If no task matches ``old_task``.
        """
        dataset_info = self.model_copy(deep=True)

        def _is_old_task(task: TaskInfo) -> bool:
            return task.name == old_task.name and task.primitive == old_task.primitive

        # Test the match condition itself, not the truthiness of the task objects.
        if not any(_is_old_task(task) for task in dataset_info.tasks):
            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")

        new_tasks = []
        for task in dataset_info.tasks:
            if not _is_old_task(task):
                new_tasks.append(task)
            elif new_task is not None:
                new_tasks.append(new_task)
            # else: matching task with no replacement -> drop it

        dataset_info.tasks = new_tasks
        return dataset_info
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class License(BaseModel):
    """License information"""

    # --- Identification -------------------------------------------------
    name: Optional[str] = Field(
        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
        max_length=100,
        default=None,
    )
    name_short: Optional[str] = Field(
        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
        max_length=100,
        default=None,
    )
    url: Optional[str] = Field(
        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
        default=None,
    )
    description: Optional[str] = Field(
        description=(
            "License description e.g. 'You must give appropriate credit, provide a "
            "link to the license, and indicate if changes were made.'"
        ),
        default=None,
    )

    # --- Validity ---------------------------------------------------------
    valid_date: Optional[datetime] = Field(
        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
        default=None,
    )

    # --- Terms and optional extras ----------------------------------------
    permissions: Optional[List[str]] = Field(
        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
        default=None,
    )
    liability: Optional[str] = Field(
        description="License liability. Optional and not always applicable.",
        default=None,
    )
    location: Optional[str] = Field(
        description=(
            "License Location. E.g. Iowa state. This is essential to understand the industry and "
            "privacy location specific rules that applies to the data. Optional and not always applicable."
        ),
        default=None,
    )
    notes: Optional[str] = Field(
        description="Additional license notes. Optional and not always applicable.",
        default=None,
    )
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class Attribution(BaseModel):
    """Attribution information for the image: Giving source and credit to the original creator"""

    # All fields are optional; free-text fields are limited to 255 characters.
    title: Optional[str] = Field(
        description="Title of the image",
        max_length=255,
        default=None,
    )
    creator: Optional[str] = Field(
        description="Creator of the image",
        max_length=255,
        default=None,
    )
    creator_url: Optional[str] = Field(
        description="URL of the creator",
        max_length=255,
        default=None,
    )
    date_captured: Optional[datetime] = Field(
        description="Date when the image was captured",
        default=None,
    )
    copyright_notice: Optional[str] = Field(
        description="Copyright notice for the image",
        max_length=255,
        default=None,
    )
    licenses: Optional[List[License]] = Field(
        description="List of licenses for the image",
        default=None,
    )
    disclaimer: Optional[str] = Field(
        description="Disclaimer for the image",
        max_length=255,
        default=None,
    )
    changes: Optional[str] = Field(
        description="Changes made to the image",
        max_length=255,
        default=None,
    )
    source_url: Optional[str] = Field(
        description="Source URL for the image",
        max_length=255,
        default=None,
    )
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class Sample(BaseModel):
    """A single dataset sample: file location, image size, split and optional annotations."""

    file_path: Optional[str] = Field(description="Path to the image/video file.")
    height: int = Field(description="Height of the image")
    width: int = Field(description="Width of the image")
    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
    tags: List[str] = Field(
        default_factory=list,
        description="Tags for a given sample. Used for creating subsets of the dataset.",
    )
    storage_format: str = Field(
        default=StorageFormat.IMAGE,
        description="Storage format. Sample data is stored as image or inside a video or zip file.",
    )
    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
    sample_index: Optional[int] = Field(
        default=None,
        description="Don't manually set this, it is used for indexing samples in the dataset.",
    )
    classifications: Optional[List[Classification]] = Field(
        default=None, description="Optional list of classifications"
    )
    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")

    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
    dataset_name: Optional[str] = Field(
        default=None,
        description=(
            "Don't manually set this, it will be automatically defined during initialization. "
            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
        ),
    )
    meta: Optional[Dict] = Field(
        default=None,
        description="Additional metadata, e.g., camera settings, GPS data, etc.",
    )

    def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
        """
        Returns a flat list of all annotations (classifications, bboxes, bitmasks, polygons) for the sample.

        Args:
            primitive_types: Primitive classes to collect. Defaults to ``PRIMITIVE_TYPES`` when
                None or empty.
        """
        primitive_types = primitive_types or PRIMITIVE_TYPES
        # Each primitive type names the sample attribute that holds its annotations.
        annotations_primitives = [
            getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
        ]
        annotations = more_itertools.flatten(
            [primitives for primitives in annotations_primitives if primitives is not None]
        )

        return list(annotations)

    def read_image_pillow(self) -> Image.Image:
        """
        Reads the image from the file path and returns it as a PIL Image.

        Raises:
            ValueError: If the sample has no file path.
            FileNotFoundError: If the image file does not exist.
        """
        if self.file_path is None:
            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
        path_image = Path(self.file_path)
        if not path_image.exists():
            raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

        image = Image.open(str(path_image))
        return image

    def read_image(self) -> np.ndarray:
        """
        Reads the sample's image as a numpy array.

        For video storage the frame at ``collection_index`` is decoded with OpenCV; for image
        storage the file is read with Pillow.
        NOTE(review): OpenCV typically returns frames in BGR channel order, whereas the Pillow
        path keeps the image's own mode - confirm that callers expect this difference.

        Raises:
            ValueError: If a video sample lacks ``file_path`` or ``collection_index``, the frame
                cannot be read, or the storage format is unsupported.
            FileNotFoundError: If an image sample's file does not exist.
        """
        if self.storage_format == StorageFormat.VIDEO:
            # Validate inputs before opening the capture so nothing is left unreleased on error.
            if self.file_path is None:
                raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
            if self.collection_index is None:
                raise ValueError("collection_index must be set for video storage format to read the correct frame.")

            video = cv2.VideoCapture(str(self.file_path))
            try:
                video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
                success, image = video.read()
            finally:
                # Always release the capture handle, even if seeking or decoding raises.
                video.release()

            if not success:
                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
            return image

        if self.storage_format == StorageFormat.IMAGE:
            image_pil = self.read_image_pillow()
            return np.array(image_pil)

        raise ValueError(f"Unsupported storage format: {self.storage_format}")

    def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Draws all annotations of the sample onto ``image`` and returns the result.

        Args:
            image: Image to draw on. When None, the image is loaded with ``read_image()``.
        """
        # Local import, kept out of the module-level imports.
        from hafnia.visualizations import image_visualizations

        if image is None:
            image = self.read_image()
        annotations = self.get_annotations()
        annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
        return annotations_visualized
|
hafnia/dataset/license_types.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from typing import List
|
|
1
|
+
from typing import List
|
|
2
2
|
|
|
3
|
-
from hafnia.dataset.
|
|
3
|
+
from hafnia.dataset.hafnia_dataset_types import License
|
|
4
4
|
|
|
5
5
|
LICENSE_TYPES: List[License] = [
|
|
6
6
|
License(
|
|
@@ -46,7 +46,7 @@ LICENSE_TYPES: List[License] = [
|
|
|
46
46
|
]
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def get_license_by_url(url: str) ->
|
|
49
|
+
def get_license_by_url(url: str) -> License:
|
|
50
50
|
for license in LICENSE_TYPES:
|
|
51
51
|
# To handle http urls
|
|
52
52
|
license_url = (license.url or "").replace("http://", "https://")
|
|
@@ -56,7 +56,7 @@ def get_license_by_url(url: str) -> Optional[License]:
|
|
|
56
56
|
raise ValueError(f"License with URL '{url}' not found.")
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def get_license_by_short_name(short_name: str) ->
|
|
59
|
+
def get_license_by_short_name(short_name: str) -> License:
|
|
60
60
|
for license in LICENSE_TYPES:
|
|
61
61
|
if license.name_short == short_name:
|
|
62
62
|
return license
|
|
@@ -5,13 +5,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
|
|
5
5
|
import polars as pl
|
|
6
6
|
import rich
|
|
7
7
|
from rich import print as rprint
|
|
8
|
-
from rich.progress import track
|
|
9
8
|
from rich.table import Table
|
|
10
9
|
|
|
11
10
|
from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
|
|
11
|
+
from hafnia.dataset.hafnia_dataset_types import Sample
|
|
12
12
|
from hafnia.dataset.operations.table_transformations import create_primitive_table
|
|
13
13
|
from hafnia.dataset.primitives import PRIMITIVE_TYPES
|
|
14
14
|
from hafnia.log import user_logger
|
|
15
|
+
from hafnia.utils import progress_bar
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type checking
|
|
17
18
|
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
@@ -188,7 +189,6 @@ def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
|
|
|
188
189
|
Performs various checks on the dataset to ensure its integrity and consistency.
|
|
189
190
|
Raises errors if any issues are found.
|
|
190
191
|
"""
|
|
191
|
-
from hafnia.dataset.hafnia_dataset import Sample
|
|
192
192
|
|
|
193
193
|
user_logger.info("Checking Hafnia dataset...")
|
|
194
194
|
assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
|
|
@@ -226,7 +226,7 @@ def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
|
|
|
226
226
|
f"classes: {class_names}. "
|
|
227
227
|
)
|
|
228
228
|
|
|
229
|
-
for sample_dict in
|
|
229
|
+
for sample_dict in progress_bar(dataset, description="Checking samples in dataset"):
|
|
230
230
|
sample = Sample(**sample_dict) # noqa: F841
|
|
231
231
|
|
|
232
232
|
|
|
@@ -31,7 +31,6 @@ that the signatures match.
|
|
|
31
31
|
|
|
32
32
|
import json
|
|
33
33
|
import re
|
|
34
|
-
import shutil
|
|
35
34
|
import textwrap
|
|
36
35
|
from pathlib import Path
|
|
37
36
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
|
|
@@ -40,7 +39,6 @@ import cv2
|
|
|
40
39
|
import more_itertools
|
|
41
40
|
import numpy as np
|
|
42
41
|
import polars as pl
|
|
43
|
-
from rich.progress import track
|
|
44
42
|
|
|
45
43
|
from hafnia.dataset import dataset_helpers
|
|
46
44
|
from hafnia.dataset.dataset_names import (
|
|
@@ -49,14 +47,15 @@ from hafnia.dataset.dataset_names import (
|
|
|
49
47
|
SampleField,
|
|
50
48
|
StorageFormat,
|
|
51
49
|
)
|
|
50
|
+
from hafnia.dataset.hafnia_dataset_types import Sample, TaskInfo
|
|
52
51
|
from hafnia.dataset.operations.table_transformations import update_class_indices
|
|
53
52
|
from hafnia.dataset.primitives import get_primitive_type_from_string
|
|
54
53
|
from hafnia.dataset.primitives.primitive import Primitive
|
|
55
54
|
from hafnia.log import user_logger
|
|
56
|
-
from hafnia.utils import remove_duplicates_preserve_order
|
|
55
|
+
from hafnia.utils import progress_bar, remove_duplicates_preserve_order
|
|
57
56
|
|
|
58
57
|
if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type checking
|
|
59
|
-
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
58
|
+
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
60
59
|
|
|
61
60
|
|
|
62
61
|
### Image transformations ###
|
|
@@ -64,7 +63,7 @@ class AnonymizeByPixelation:
|
|
|
64
63
|
def __init__(self, resize_factor: float = 0.10):
|
|
65
64
|
self.resize_factor = resize_factor
|
|
66
65
|
|
|
67
|
-
def __call__(self, frame: np.ndarray, sample:
|
|
66
|
+
def __call__(self, frame: np.ndarray, sample: Sample) -> np.ndarray:
|
|
68
67
|
org_size = frame.shape[:2]
|
|
69
68
|
frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
|
|
70
69
|
frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
|
|
@@ -73,17 +72,15 @@ class AnonymizeByPixelation:
|
|
|
73
72
|
|
|
74
73
|
def transform_images(
|
|
75
74
|
dataset: "HafniaDataset",
|
|
76
|
-
transform: Callable[[np.ndarray,
|
|
75
|
+
transform: Callable[[np.ndarray, Sample], np.ndarray],
|
|
77
76
|
path_output: Path,
|
|
78
77
|
description: str = "Transform images",
|
|
79
78
|
) -> "HafniaDataset":
|
|
80
|
-
from hafnia.dataset.hafnia_dataset import Sample
|
|
81
|
-
|
|
82
79
|
new_paths = []
|
|
83
80
|
path_image_folder = path_output / "data"
|
|
84
81
|
path_image_folder.mkdir(parents=True, exist_ok=True)
|
|
85
82
|
|
|
86
|
-
for sample_dict in
|
|
83
|
+
for sample_dict in progress_bar(dataset, description=description):
|
|
87
84
|
sample = Sample(**sample_dict)
|
|
88
85
|
image = sample.read_image()
|
|
89
86
|
image_transformed = transform(image, sample)
|
|
@@ -102,15 +99,15 @@ def convert_to_image_storage_format(
|
|
|
102
99
|
path_output_folder: Path,
|
|
103
100
|
reextract_frames: bool,
|
|
104
101
|
image_format: str = "png",
|
|
105
|
-
transform: Optional[Callable[[np.ndarray,
|
|
102
|
+
transform: Optional[Callable[[np.ndarray, Sample], np.ndarray]] = None,
|
|
106
103
|
) -> "HafniaDataset":
|
|
107
104
|
"""
|
|
108
105
|
Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
|
|
109
106
|
to an image-based dataset by extracting frames.
|
|
110
107
|
"""
|
|
111
|
-
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
108
|
+
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
112
109
|
|
|
113
|
-
path_images = path_output_folder / "data"
|
|
110
|
+
path_images = (path_output_folder / "data").absolute()
|
|
114
111
|
path_images.mkdir(parents=True, exist_ok=True)
|
|
115
112
|
|
|
116
113
|
# Only video format dataset samples are processed
|
|
@@ -128,7 +125,7 @@ def convert_to_image_storage_format(
|
|
|
128
125
|
video = cv2.VideoCapture(str(path_video))
|
|
129
126
|
|
|
130
127
|
video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
|
|
131
|
-
for sample_dict in
|
|
128
|
+
for sample_dict in progress_bar(
|
|
132
129
|
video_samples.iter_rows(named=True),
|
|
133
130
|
total=video_samples.height,
|
|
134
131
|
description=f"Extracting frames from '{Path(path_video).name}'",
|
|
@@ -147,7 +144,7 @@ def convert_to_image_storage_format(
|
|
|
147
144
|
}
|
|
148
145
|
)
|
|
149
146
|
if reextract_frames:
|
|
150
|
-
|
|
147
|
+
path_image.unlink(missing_ok=True)
|
|
151
148
|
if path_image.exists():
|
|
152
149
|
continue
|
|
153
150
|
|
|
@@ -168,10 +165,10 @@ def convert_to_image_storage_format(
|
|
|
168
165
|
|
|
169
166
|
|
|
170
167
|
def get_task_info_from_task_name_and_primitive(
|
|
171
|
-
tasks: List[
|
|
168
|
+
tasks: List[TaskInfo],
|
|
172
169
|
task_name: Optional[str] = None,
|
|
173
170
|
primitive: Union[None, str, Type[Primitive]] = None,
|
|
174
|
-
) ->
|
|
171
|
+
) -> TaskInfo:
|
|
175
172
|
if len(tasks) == 0:
|
|
176
173
|
raise ValueError("Dataset has no tasks defined.")
|
|
177
174
|
|
|
@@ -423,7 +420,7 @@ def _validate_inputs_select_samples_by_class_name(
|
|
|
423
420
|
name: Union[List[str], str],
|
|
424
421
|
task_name: Optional[str] = None,
|
|
425
422
|
primitive: Optional[Type[Primitive]] = None,
|
|
426
|
-
) -> Tuple[
|
|
423
|
+
) -> Tuple[TaskInfo, List[str]]:
|
|
427
424
|
if isinstance(name, str):
|
|
428
425
|
name = [name]
|
|
429
426
|
names = list(name)
|