hafnia 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. hafnia/dataset/{dataset_upload_helper.py → dataset_details_uploader.py} +114 -191
  2. hafnia/dataset/dataset_names.py +26 -0
  3. hafnia/dataset/format_conversions/format_coco.py +490 -0
  4. hafnia/dataset/format_conversions/format_helpers.py +33 -0
  5. hafnia/dataset/format_conversions/format_image_classification_folder.py +95 -14
  6. hafnia/dataset/format_conversions/format_yolo.py +115 -25
  7. hafnia/dataset/format_conversions/torchvision_datasets.py +10 -8
  8. hafnia/dataset/hafnia_dataset.py +20 -466
  9. hafnia/dataset/hafnia_dataset_types.py +477 -0
  10. hafnia/dataset/license_types.py +4 -4
  11. hafnia/dataset/operations/dataset_stats.py +3 -3
  12. hafnia/dataset/operations/dataset_transformations.py +14 -17
  13. hafnia/dataset/operations/table_transformations.py +20 -13
  14. hafnia/dataset/primitives/bbox.py +6 -2
  15. hafnia/dataset/primitives/bitmask.py +21 -46
  16. hafnia/dataset/primitives/classification.py +1 -1
  17. hafnia/dataset/primitives/polygon.py +43 -2
  18. hafnia/dataset/primitives/primitive.py +1 -1
  19. hafnia/dataset/primitives/segmentation.py +1 -1
  20. hafnia/experiment/hafnia_logger.py +13 -4
  21. hafnia/platform/datasets.py +2 -3
  22. hafnia/torch_helpers.py +48 -4
  23. hafnia/utils.py +34 -0
  24. hafnia/visualizations/image_visualizations.py +3 -1
  25. {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/METADATA +2 -2
  26. {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/RECORD +29 -26
  27. {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/WHEEL +0 -0
  28. {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/entry_points.txt +0 -0
  29. {hafnia-0.4.2.dist-info → hafnia-0.4.3.dist-info}/licenses/LICENSE +0 -0
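Review note: the biggest structural change in this release is that the dataset metadata types (TaskInfo, DatasetInfo, License, Attribution, Sample) move out of hafnia/dataset/hafnia_dataset.py into the new hafnia/dataset/hafnia_dataset_types.py module shown below. A minimal sketch of what that means for import statements, assuming the classes keep the interfaces shown in the hunks that follow:

from hafnia.dataset.hafnia_dataset_types import Attribution, DatasetInfo, License, Sample, TaskInfo

The later hunks in this diff (license_types.py, dataset_stats.py, dataset_transformations.py) make exactly this switch for License, Sample and TaskInfo.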
File: hafnia/dataset/hafnia_dataset_types.py (new file)
@@ -0,0 +1,477 @@
+import collections
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Type, Union
+
+import cv2
+import more_itertools
+import numpy as np
+from packaging.version import Version
+from PIL import Image
+from pydantic import BaseModel, Field, field_serializer, field_validator
+
+import hafnia
+from hafnia.dataset.dataset_names import SampleField, StorageFormat
+from hafnia.dataset.primitives import (
+    PRIMITIVE_TYPES,
+    Bbox,
+    Bitmask,
+    Classification,
+    Polygon,
+    get_primitive_type_from_string,
+)
+from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
+
+
+class TaskInfo(BaseModel):
+    primitive: Type[Primitive] = Field(
+        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+    )
+    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+    name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Optional name for the task. 'None' will use default name of the provided primitive. "
+            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+        ),
+    )
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.name is None:
+            self.name = self.primitive.default_task_name()
+
+    def get_class_index(self, class_name: str) -> int:
+        """Get class index for a given class name"""
+        if self.class_names is None:
+            raise ValueError(f"Task '{self.name}' has no class names defined.")
+        if class_name not in self.class_names:
+            raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
+        return self.class_names.index(class_name)
+
+    # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box, as
+    # the 'Primitive' class is an abstract base class for the actual primitives such as Bbox, Bitmask, Classification.
+    # The magic functions below ('ensure_primitive' and 'serialize_primitive') ensure that the 'primitive' field can
+    # correctly validate and serialize sub-classes (Bbox, Classification, ...).
+    @field_validator("primitive", mode="plain")
+    @classmethod
+    def ensure_primitive(cls, primitive: Any) -> Any:
+        if isinstance(primitive, str):
+            return get_primitive_type_from_string(primitive)
+
+        if issubclass(primitive, Primitive):
+            return primitive
+
+        raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")
+
+    @field_serializer("primitive")
+    @classmethod
+    def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
+        if not issubclass(primitive, Primitive):
+            raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
+        return primitive.__name__
+
+    @field_validator("class_names", mode="after")
+    @classmethod
+    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
+        """Validate that class names are unique"""
+        if class_names is None:
+            return None
+        duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
+        if duplicate_class_names:
+            raise ValueError(
+                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
+            )
+        return class_names
+
+    def full_name(self) -> str:
+        """Get qualified name for the task: <primitive_name>:<task_name>"""
+        return f"{self.primitive.__name__}:{self.name}"
+
+    # To get unique hash value for TaskInfo objects
+    def __hash__(self) -> int:
+        class_names = self.class_names or []
+        return hash((self.name, self.primitive.__name__, tuple(class_names)))
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, TaskInfo):
+            return False
+        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
+
+
+class DatasetInfo(BaseModel):
+    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+    version: Optional[str] = Field(default=None, description="Version of the dataset")
+    tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+    reference_bibtex: Optional[str] = Field(
+        default=None,
+        description="Optional, BibTeX reference to dataset publication",
+    )
+    reference_paper_url: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to dataset publication",
+    )
+    reference_dataset_page: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to the dataset page",
+    )
+    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+    format_version: str = Field(
+        default=hafnia.__dataset_format_version__,
+        description="Version of the Hafnia dataset format. You should not set this manually.",
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.now,
+        description="Timestamp of the last update to the dataset info. You should not set this manually.",
+    )
+
+    @field_validator("tasks", mode="after")
+    @classmethod
+    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+        if tasks is None:
+            return []
+        task_name_counts = collections.Counter(task.name for task in tasks)
+        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
+        if duplicate_task_names:
+            raise ValueError(
+                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
+            )
+        return tasks
+
+    @field_validator("format_version")
+    @classmethod
+    def _validate_format_version(cls, format_version: str) -> str:
+        try:
+            Version(format_version)
+        except Exception as e:
+            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+        if Version(format_version) > Version(hafnia.__dataset_format_version__):
+            user_logger.warning(
+                f"The loaded dataset format version '{format_version}' is newer than the format version "
+                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                f"updating Hafnia package."
+            )
+        return format_version
+
+    @field_validator("version")
+    @classmethod
+    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+        if dataset_version is None:
+            return None
+
+        try:
+            Version(dataset_version)
+        except Exception as e:
+            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+        return dataset_version
+
+    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
+        return self._validate_check_for_duplicate_tasks(self.tasks)
+
+    def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
+        json_str = self.model_dump_json(indent=indent)
+        path.write_text(json_str)
+
+    @staticmethod
+    def from_json_file(path: Path) -> "DatasetInfo":
+        json_str = path.read_text()
+
+        # TODO: Deprecated support for old dataset info without format_version
+        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
+        # when all datasets include a 'format_version' field
+        json_dict = json.loads(json_str)
+        if "format_version" not in json_dict:
+            json_dict["format_version"] = "0.0.0"
+
+        if "updated_at" not in json_dict:
+            json_dict["updated_at"] = datetime.min.isoformat()
+        dataset_info = DatasetInfo.model_validate(json_dict)
+
+        return dataset_info
+
+    @staticmethod
+    def merge(info0: "DatasetInfo", info1: "DatasetInfo") -> "DatasetInfo":
+        """
+        Merges two DatasetInfo objects into one and validates if they are compatible.
+        """
+        for task_ds0 in info0.tasks:
+            for task_ds1 in info1.tasks:
+                same_name = task_ds0.name == task_ds1.name
+                same_primitive = task_ds0.primitive == task_ds1.primitive
+                same_name_different_primitive = same_name and not same_primitive
+                if same_name_different_primitive:
+                    raise ValueError(
+                        f"Cannot merge datasets with different primitives for the same task name: "
+                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
+                        f"'{task_ds1.primitive}' in dataset1."
+                    )
+
+                is_same_name_and_primitive = same_name and same_primitive
+                if is_same_name_and_primitive:
+                    task_ds0_class_names = task_ds0.class_names or []
+                    task_ds1_class_names = task_ds1.class_names or []
+                    if task_ds0_class_names != task_ds1_class_names:
+                        raise ValueError(
+                            f"Cannot merge datasets with different class names for the same task name and primitive: "
+                            f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
+                            f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
+                        )
+
+        if info1.format_version != info0.format_version:
+            user_logger.warning(
+                "Dataset format version of the two datasets do not match. "
+                f"'{info1.format_version}' vs '{info0.format_version}'."
+            )
+        dataset_format_version = info0.format_version
+        if hafnia.__dataset_format_version__ != dataset_format_version:
+            user_logger.warning(
+                f"Dataset format version '{dataset_format_version}' does not match the current "
+                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
+            )
+        unique_tasks = set(info0.tasks + info1.tasks)
+        meta = (info0.meta or {}).copy()
+        meta.update(info1.meta or {})
+        return DatasetInfo(
+            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
+            version=None,
+            tasks=list(unique_tasks),
+            meta=meta,
+            format_version=dataset_format_version,
+        )
+
+    def get_task_by_name(self, task_name: str) -> TaskInfo:
+        """
+        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
+        """
+        tasks_with_name = [task for task in self.tasks if task.name == task_name]
+        if not tasks_with_name:
+            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
+        if len(tasks_with_name) > 1:
+            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
+        return tasks_with_name[0]
+
+    def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
+        """
+        Get all tasks by their primitive type.
+        """
+        if isinstance(primitive, str):
+            primitive = get_primitive_type_from_string(primitive)
+
+        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
+        return tasks_with_primitive
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+
+        tasks_with_primitive = self.get_tasks_by_primitive(primitive)
+        if len(tasks_with_primitive) == 0:
+            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
+        if len(tasks_with_primitive) > 1:
+            raise ValueError(
+                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
+            )
+        return tasks_with_primitive[0]
+
+    def get_task_by_task_name_and_primitive(
+        self,
+        task_name: Optional[str],
+        primitive: Optional[Union[Type[Primitive], str]],
+    ) -> TaskInfo:
+        """
+        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
+        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
+        """
+        from hafnia.dataset.operations import dataset_transformations
+
+        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
+            tasks=self.tasks,
+            primitive=primitive,
+            task_name=task_name,
+        )
+        return task
+
+    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> "DatasetInfo":
+        dataset_info = self.model_copy(deep=True)
+        has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
+        if not has_task:
+            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
+
+        new_tasks = []
+        for task in dataset_info.tasks:
+            if task.name == old_task.name and task.primitive == old_task.primitive:
+                if new_task is None:
+                    continue  # Remove the task
+                new_tasks.append(new_task)
+            else:
+                new_tasks.append(task)
+
+        dataset_info.tasks = new_tasks
+        return dataset_info
+
+
+class License(BaseModel):
+    """License information"""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
+        max_length=100,
+    )
+    name_short: Optional[str] = Field(
+        default=None,
+        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
+        max_length=100,
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description=(
+            "License description e.g. 'You must give appropriate credit, provide a "
+            "link to the license, and indicate if changes were made.'"
+        ),
+    )
+
+    valid_date: Optional[datetime] = Field(
+        default=None,
+        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
+    )
+
+    permissions: Optional[List[str]] = Field(
+        default=None,
+        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
+    )
+    liability: Optional[str] = Field(
+        default=None,
+        description="License liability. Optional and not always applicable.",
+    )
+    location: Optional[str] = Field(
+        default=None,
+        description=(
+            "License Location. E.g. Iowa state. This is essential to understand the industry and "
+            "privacy location specific rules that applies to the data. Optional and not always applicable."
+        ),
+    )
+    notes: Optional[str] = Field(
+        default=None,
+        description="Additional license notes. Optional and not always applicable.",
+    )
+
+
+class Attribution(BaseModel):
+    """Attribution information for the image: Giving source and credit to the original creator"""
+
+    title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
+    creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
+    creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
+    date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
+    copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
+    licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
+    disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
+    changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
+    source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+
+
+class Sample(BaseModel):
+    file_path: Optional[str] = Field(description="Path to the image/video file.")
+    height: int = Field(description="Height of the image")
+    width: int = Field(description="Width of the image")
+    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+    tags: List[str] = Field(
+        default_factory=list,
+        description="Tags for a given sample. Used for creating subsets of the dataset.",
+    )
+    storage_format: str = Field(
+        default=StorageFormat.IMAGE,
+        description="Storage format. Sample data is stored as image or inside a video or zip file.",
+    )
+    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+    sample_index: Optional[int] = Field(
+        default=None,
+        description="Don't manually set this, it is used for indexing samples in the dataset.",
+    )
+    classifications: Optional[List[Classification]] = Field(
+        default=None, description="Optional list of classifications"
+    )
+    bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
+    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+    dataset_name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Don't manually set this, it will be automatically defined during initialization. "
+            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+        ),
+    )
+    meta: Optional[Dict] = Field(
+        default=None,
+        description="Additional metadata, e.g., camera settings, GPS data, etc.",
+    )
+
+    def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
+        """
+        Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.
+        """
+        primitive_types = primitive_types or PRIMITIVE_TYPES
+        annotations_primitives = [
+            getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
+        ]
+        annotations = more_itertools.flatten(
+            [primitives for primitives in annotations_primitives if primitives is not None]
+        )
+
+        return list(annotations)
+
+    def read_image_pillow(self) -> Image.Image:
+        """
+        Reads the image from the file path and returns it as a PIL Image.
+        Raises FileNotFoundError if the image file does not exist.
+        """
+        if self.file_path is None:
+            raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
+        path_image = Path(self.file_path)
+        if not path_image.exists():
+            raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
+
+        image = Image.open(str(path_image))
+        return image
+
+    def read_image(self) -> np.ndarray:
+        if self.storage_format == StorageFormat.VIDEO:
+            video = cv2.VideoCapture(str(self.file_path))
+            if self.collection_index is None:
+                raise ValueError("collection_index must be set for video storage format to read the correct frame.")
+            video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
+            success, image = video.read()
+            video.release()
+            if not success:
+                raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
+            return image
+
+        elif self.storage_format == StorageFormat.IMAGE:
+            image_pil = self.read_image_pillow()
+            image = np.array(image_pil)
+        else:
+            raise ValueError(f"Unsupported storage format: {self.storage_format}")
+        return image
+
+    def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
+        from hafnia.visualizations import image_visualizations
+
+        if image is None:
+            image = self.read_image()
+        annotations = self.get_annotations()
+        annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
+        return annotations_visualized
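The new module above is self-contained, so the intended usage can be read off the definitions directly. A hedged sketch (class names and the default task name are illustrative, taken from the field descriptions above rather than documented behaviour):

from hafnia.dataset.hafnia_dataset_types import DatasetInfo, TaskInfo
from hafnia.dataset.primitives import Bbox

# A TaskInfo accepts the primitive as a class or string; 'name' falls back to the
# primitive's default task name in model_post_init (e.g. Bbox -> 'bboxes').
detection_task = TaskInfo(primitive=Bbox, class_names=["person", "vehicle"])

info_a = DatasetInfo(dataset_name="dataset-a", tasks=[detection_task])
info_b = DatasetInfo(dataset_name="dataset-b", tasks=[detection_task])

# merge() validates task compatibility, de-duplicates identical tasks via
# TaskInfo.__hash__/__eq__, and concatenates the dataset names.
merged = DatasetInfo.merge(info_a, info_b)
assert merged.dataset_name == "dataset-a+dataset-b"
assert len(merged.tasks) == 1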
File: hafnia/dataset/license_types.py
@@ -1,6 +1,6 @@
-from typing import List, Optional
+from typing import List
 
-from hafnia.dataset.hafnia_dataset import License
+from hafnia.dataset.hafnia_dataset_types import License
 
 LICENSE_TYPES: List[License] = [
     License(
@@ -46,7 +46,7 @@ LICENSE_TYPES: List[License] = [
 ]
 
 
-def get_license_by_url(url: str) -> Optional[License]:
+def get_license_by_url(url: str) -> License:
     for license in LICENSE_TYPES:
         # To handle http urls
         license_url = (license.url or "").replace("http://", "https://")
@@ -56,7 +56,7 @@ def get_license_by_url(url: str) -> Optional[License]:
     raise ValueError(f"License with URL '{url}' not found.")
 
 
-def get_license_by_short_name(short_name: str) -> Optional[License]:
+def get_license_by_short_name(short_name: str) -> License:
     for license in LICENSE_TYPES:
         if license.name_short == short_name:
             return license
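The license_types.py change above narrows the return annotations from Optional[License] to License, matching the behaviour already visible in the hunk: each lookup either returns a matching entry from LICENSE_TYPES or raises. A small hedged sketch of the resulting caller contract (the URL is the one used in the License.url field description and is assumed, not verified, to be present in LICENSE_TYPES):

from hafnia.dataset.license_types import get_license_by_url

license = get_license_by_url("https://creativecommons.org/licenses/by/4.0/")
print(license.name_short)  # callers no longer need a None check; an unknown URL raises ValueError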
File: hafnia/dataset/operations/dataset_stats.py
@@ -5,13 +5,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 import polars as pl
 import rich
 from rich import print as rprint
-from rich.progress import track
 from rich.table import Table
 
 from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
+from hafnia.dataset.hafnia_dataset_types import Sample
 from hafnia.dataset.operations.table_transformations import create_primitive_table
 from hafnia.dataset.primitives import PRIMITIVE_TYPES
 from hafnia.log import user_logger
+from hafnia.utils import progress_bar
 
 if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
     from hafnia.dataset.hafnia_dataset import HafniaDataset
@@ -188,7 +189,6 @@ def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
     Performs various checks on the dataset to ensure its integrity and consistency.
     Raises errors if any issues are found.
     """
-    from hafnia.dataset.hafnia_dataset import Sample
 
     user_logger.info("Checking Hafnia dataset...")
     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
@@ -226,7 +226,7 @@ def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
                 f"classes: {class_names}. "
             )
 
-    for sample_dict in track(dataset, description="Checking samples in dataset"):
+    for sample_dict in progress_bar(dataset, description="Checking samples in dataset"):
         sample = Sample(**sample_dict)  # noqa: F841
 
 
File: hafnia/dataset/operations/dataset_transformations.py
@@ -31,7 +31,6 @@ that the signatures match.
 
 import json
 import re
-import shutil
 import textwrap
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -40,7 +39,6 @@ import cv2
 import more_itertools
 import numpy as np
 import polars as pl
-from rich.progress import track
 
 from hafnia.dataset import dataset_helpers
 from hafnia.dataset.dataset_names import (
@@ -49,14 +47,15 @@ from hafnia.dataset.dataset_names import (
     SampleField,
     StorageFormat,
 )
+from hafnia.dataset.hafnia_dataset_types import Sample, TaskInfo
 from hafnia.dataset.operations.table_transformations import update_class_indices
 from hafnia.dataset.primitives import get_primitive_type_from_string
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.log import user_logger
-from hafnia.utils import remove_duplicates_preserve_order
+from hafnia.utils import progress_bar, remove_duplicates_preserve_order
 
 if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
-    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample, TaskInfo
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
 
 
 ### Image transformations ###
@@ -64,7 +63,7 @@ class AnonymizeByPixelation:
     def __init__(self, resize_factor: float = 0.10):
         self.resize_factor = resize_factor
 
-    def __call__(self, frame: np.ndarray, sample: "Sample") -> np.ndarray:
+    def __call__(self, frame: np.ndarray, sample: Sample) -> np.ndarray:
         org_size = frame.shape[:2]
         frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
         frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
@@ -73,17 +72,15 @@ class AnonymizeByPixelation:
 
 def transform_images(
     dataset: "HafniaDataset",
-    transform: Callable[[np.ndarray, "Sample"], np.ndarray],
+    transform: Callable[[np.ndarray, Sample], np.ndarray],
     path_output: Path,
     description: str = "Transform images",
 ) -> "HafniaDataset":
-    from hafnia.dataset.hafnia_dataset import Sample
-
     new_paths = []
     path_image_folder = path_output / "data"
     path_image_folder.mkdir(parents=True, exist_ok=True)
 
-    for sample_dict in track(dataset, description=description):
+    for sample_dict in progress_bar(dataset, description=description):
         sample = Sample(**sample_dict)
         image = sample.read_image()
         image_transformed = transform(image, sample)
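With Sample now imported at module level, transform_images keeps the same call shape as before; only the type annotations change. A hedged sketch of applying the AnonymizeByPixelation callable defined above (the 'from_path' loader is a hypothetical placeholder, dataset loading is not part of this diff):

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.dataset.operations.dataset_transformations import AnonymizeByPixelation, transform_images

dataset = HafniaDataset.from_path(Path("./my_dataset"))  # hypothetical loader, not shown in this diff
anonymized = transform_images(
    dataset,
    transform=AnonymizeByPixelation(resize_factor=0.10),
    path_output=Path("./my_dataset_anonymized"),
    description="Pixelating images",
)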
@@ -102,15 +99,15 @@ def convert_to_image_storage_format(
     path_output_folder: Path,
     reextract_frames: bool,
     image_format: str = "png",
-    transform: Optional[Callable[[np.ndarray, "Sample"], np.ndarray]] = None,
+    transform: Optional[Callable[[np.ndarray, Sample], np.ndarray]] = None,
 ) -> "HafniaDataset":
     """
     Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
     to an image-based dataset by extracting frames.
     """
-    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
 
-    path_images = path_output_folder / "data"
+    path_images = (path_output_folder / "data").absolute()
     path_images.mkdir(parents=True, exist_ok=True)
 
     # Only video format dataset samples are processed
@@ -128,7 +125,7 @@ def convert_to_image_storage_format(
         video = cv2.VideoCapture(str(path_video))
 
         video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
-        for sample_dict in track(
+        for sample_dict in progress_bar(
             video_samples.iter_rows(named=True),
             total=video_samples.height,
             description=f"Extracting frames from '{Path(path_video).name}'",
@@ -147,7 +144,7 @@ def convert_to_image_storage_format(
                }
            )
            if reextract_frames:
-                path_image and shutil.rmtree(path_image, ignore_errors=True)
+                path_image.unlink(missing_ok=True)
            if path_image.exists():
                continue
 
@@ -168,10 +165,10 @@ def convert_to_image_storage_format(
 
 
 def get_task_info_from_task_name_and_primitive(
-    tasks: List["TaskInfo"],
+    tasks: List[TaskInfo],
    task_name: Optional[str] = None,
    primitive: Union[None, str, Type[Primitive]] = None,
-) -> "TaskInfo":
+) -> TaskInfo:
     if len(tasks) == 0:
         raise ValueError("Dataset has no tasks defined.")
 
@@ -423,7 +420,7 @@ def _validate_inputs_select_samples_by_class_name(
     name: Union[List[str], str],
     task_name: Optional[str] = None,
     primitive: Optional[Type[Primitive]] = None,
-) -> Tuple["TaskInfo", List[str]]:
+) -> Tuple[TaskInfo, List[str]]:
     if isinstance(name, str):
         name = [name]
     names = list(name)