hafnia 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. hafnia/dataset/{dataset_upload_helper.py → dataset_details_uploader.py} +115 -192
  2. hafnia/dataset/dataset_names.py +26 -0
  3. hafnia/dataset/dataset_recipe/dataset_recipe.py +3 -3
  4. hafnia/dataset/format_conversions/format_coco.py +490 -0
  5. hafnia/dataset/format_conversions/format_helpers.py +33 -0
  6. hafnia/dataset/format_conversions/format_image_classification_folder.py +95 -14
  7. hafnia/dataset/format_conversions/format_yolo.py +115 -25
  8. hafnia/dataset/format_conversions/torchvision_datasets.py +10 -8
  9. hafnia/dataset/hafnia_dataset.py +20 -466
  10. hafnia/dataset/hafnia_dataset_types.py +477 -0
  11. hafnia/dataset/license_types.py +4 -4
  12. hafnia/dataset/operations/dataset_stats.py +3 -3
  13. hafnia/dataset/operations/dataset_transformations.py +14 -17
  14. hafnia/dataset/operations/table_transformations.py +20 -13
  15. hafnia/dataset/primitives/bbox.py +6 -2
  16. hafnia/dataset/primitives/bitmask.py +21 -46
  17. hafnia/dataset/primitives/classification.py +1 -1
  18. hafnia/dataset/primitives/polygon.py +43 -2
  19. hafnia/dataset/primitives/primitive.py +1 -1
  20. hafnia/dataset/primitives/segmentation.py +1 -1
  21. hafnia/experiment/hafnia_logger.py +13 -4
  22. hafnia/platform/datasets.py +3 -4
  23. hafnia/torch_helpers.py +48 -4
  24. hafnia/utils.py +35 -1
  25. hafnia/visualizations/image_visualizations.py +3 -1
  26. {hafnia-0.4.1.dist-info → hafnia-0.4.3.dist-info}/METADATA +2 -2
  27. hafnia-0.4.3.dist-info/RECORD +60 -0
  28. hafnia-0.4.3.dist-info/entry_points.txt +2 -0
  29. {cli → hafnia_cli}/__main__.py +2 -2
  30. {cli → hafnia_cli}/config.py +2 -2
  31. {cli → hafnia_cli}/dataset_cmds.py +2 -2
  32. {cli → hafnia_cli}/dataset_recipe_cmds.py +1 -1
  33. {cli → hafnia_cli}/experiment_cmds.py +1 -1
  34. {cli → hafnia_cli}/profile_cmds.py +2 -2
  35. {cli → hafnia_cli}/runc_cmds.py +1 -1
  36. {cli → hafnia_cli}/trainer_package_cmds.py +2 -2
  37. hafnia-0.4.1.dist-info/RECORD +0 -57
  38. hafnia-0.4.1.dist-info/entry_points.txt +0 -2
  39. {hafnia-0.4.1.dist-info → hafnia-0.4.3.dist-info}/WHEEL +0 -0
  40. {hafnia-0.4.1.dist-info → hafnia-0.4.3.dist-info}/licenses/LICENSE +0 -0
  41. {cli → hafnia_cli}/__init__.py +0 -0
  42. {cli → hafnia_cli}/consts.py +0 -0
  43. {cli → hafnia_cli}/keychain.py +0 -0
@@ -1,25 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- import collections
4
3
  import copy
5
- import json
6
4
  import shutil
7
5
  from dataclasses import dataclass
8
- from datetime import datetime
9
6
  from pathlib import Path
10
7
  from random import Random
11
8
  from typing import Any, Dict, List, Optional, Tuple, Type, Union
12
9
 
13
- import cv2
14
- import more_itertools
15
- import numpy as np
16
10
  import polars as pl
17
11
  from packaging.version import Version
18
- from PIL import Image
19
- from pydantic import BaseModel, Field, field_serializer, field_validator
20
- from rich.progress import track
21
12
 
22
- import hafnia
23
13
  from hafnia.dataset import dataset_helpers
24
14
  from hafnia.dataset.dataset_names import (
25
15
  DATASET_FILENAMES_REQUIRED,
@@ -35,470 +25,19 @@ from hafnia.dataset.dataset_names import (
35
25
  StorageFormat,
36
26
  )
37
27
  from hafnia.dataset.format_conversions import (
28
+ format_coco,
38
29
  format_image_classification_folder,
39
30
  format_yolo,
40
31
  )
32
+ from hafnia.dataset.hafnia_dataset_types import DatasetInfo, Sample
41
33
  from hafnia.dataset.operations import (
42
34
  dataset_stats,
43
35
  dataset_transformations,
44
36
  table_transformations,
45
37
  )
46
- from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
47
- from hafnia.dataset.primitives.bbox import Bbox
48
- from hafnia.dataset.primitives.bitmask import Bitmask
49
- from hafnia.dataset.primitives.classification import Classification
50
- from hafnia.dataset.primitives.polygon import Polygon
51
38
  from hafnia.dataset.primitives.primitive import Primitive
52
39
  from hafnia.log import user_logger
53
-
54
-
55
- class TaskInfo(BaseModel):
56
- primitive: Type[Primitive] = Field(
57
- description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
58
- )
59
- class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
60
- name: Optional[str] = Field(
61
- default=None,
62
- description=(
63
- "Optional name for the task. 'None' will use default name of the provided primitive. "
64
- "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
65
- ),
66
- )
67
-
68
- def model_post_init(self, __context: Any) -> None:
69
- if self.name is None:
70
- self.name = self.primitive.default_task_name()
71
-
72
- def get_class_index(self, class_name: str) -> int:
73
- """Get class index for a given class name"""
74
- if self.class_names is None:
75
- raise ValueError(f"Task '{self.name}' has no class names defined.")
76
- if class_name not in self.class_names:
77
- raise ValueError(f"Class name '{class_name}' not found in task '{self.name}'.")
78
- return self.class_names.index(class_name)
79
-
80
- # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
81
- # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
82
- # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
83
- # correctly validate and serialize sub-classes (Bbox, Classification, ...).
84
- @field_validator("primitive", mode="plain")
85
- @classmethod
86
- def ensure_primitive(cls, primitive: Any) -> Any:
87
- if isinstance(primitive, str):
88
- return get_primitive_type_from_string(primitive)
89
-
90
- if issubclass(primitive, Primitive):
91
- return primitive
92
-
93
- raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")
94
-
95
- @field_serializer("primitive")
96
- @classmethod
97
- def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
98
- if not issubclass(primitive, Primitive):
99
- raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
100
- return primitive.__name__
101
-
102
- @field_validator("class_names", mode="after")
103
- @classmethod
104
- def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
105
- """Validate that class names are unique"""
106
- if class_names is None:
107
- return None
108
- duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
109
- if duplicate_class_names:
110
- raise ValueError(
111
- f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
112
- )
113
- return class_names
114
-
115
- def full_name(self) -> str:
116
- """Get qualified name for the task: <primitive_name>:<task_name>"""
117
- return f"{self.primitive.__name__}:{self.name}"
118
-
119
- # To get unique hash value for TaskInfo objects
120
- def __hash__(self) -> int:
121
- class_names = self.class_names or []
122
- return hash((self.name, self.primitive.__name__, tuple(class_names)))
123
-
124
- def __eq__(self, other: Any) -> bool:
125
- if not isinstance(other, TaskInfo):
126
- return False
127
- return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
128
-
129
-
130
- class DatasetInfo(BaseModel):
131
- dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
132
- version: Optional[str] = Field(default=None, description="Version of the dataset")
133
- tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
134
- reference_bibtex: Optional[str] = Field(
135
- default=None,
136
- description="Optional, BibTeX reference to dataset publication",
137
- )
138
- reference_paper_url: Optional[str] = Field(
139
- default=None,
140
- description="Optional, URL to dataset publication",
141
- )
142
- reference_dataset_page: Optional[str] = Field(
143
- default=None,
144
- description="Optional, URL to the dataset page",
145
- )
146
- meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
147
- format_version: str = Field(
148
- default=hafnia.__dataset_format_version__,
149
- description="Version of the Hafnia dataset format. You should not set this manually.",
150
- )
151
- updated_at: datetime = Field(
152
- default_factory=datetime.now,
153
- description="Timestamp of the last update to the dataset info. You should not set this manually.",
154
- )
155
-
156
- @field_validator("tasks", mode="after")
157
- @classmethod
158
- def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
159
- if tasks is None:
160
- return []
161
- task_name_counts = collections.Counter(task.name for task in tasks)
162
- duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
163
- if duplicate_task_names:
164
- raise ValueError(
165
- f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
166
- )
167
- return tasks
168
-
169
- @field_validator("format_version")
170
- @classmethod
171
- def _validate_format_version(cls, format_version: str) -> str:
172
- try:
173
- Version(format_version)
174
- except Exception as e:
175
- raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
176
-
177
- if Version(format_version) > Version(hafnia.__dataset_format_version__):
178
- user_logger.warning(
179
- f"The loaded dataset format version '{format_version}' is newer than the format version "
180
- f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
181
- f"updating Hafnia package."
182
- )
183
- return format_version
184
-
185
- @field_validator("version")
186
- @classmethod
187
- def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
188
- if dataset_version is None:
189
- return None
190
-
191
- try:
192
- Version(dataset_version)
193
- except Exception as e:
194
- raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
195
-
196
- return dataset_version
197
-
198
- def check_for_duplicate_task_names(self) -> List[TaskInfo]:
199
- return self._validate_check_for_duplicate_tasks(self.tasks)
200
-
201
- def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
202
- json_str = self.model_dump_json(indent=indent)
203
- path.write_text(json_str)
204
-
205
- @staticmethod
206
- def from_json_file(path: Path) -> DatasetInfo:
207
- json_str = path.read_text()
208
-
209
- # TODO: Deprecated support for old dataset info without format_version
210
- # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
211
- # when all datasets include a 'format_version' field
212
- json_dict = json.loads(json_str)
213
- if "format_version" not in json_dict:
214
- json_dict["format_version"] = "0.0.0"
215
-
216
- if "updated_at" not in json_dict:
217
- json_dict["updated_at"] = datetime.min.isoformat()
218
- dataset_info = DatasetInfo.model_validate(json_dict)
219
-
220
- return dataset_info
221
-
222
- @staticmethod
223
- def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
224
- """
225
- Merges two DatasetInfo objects into one and validates if they are compatible.
226
- """
227
- for task_ds0 in info0.tasks:
228
- for task_ds1 in info1.tasks:
229
- same_name = task_ds0.name == task_ds1.name
230
- same_primitive = task_ds0.primitive == task_ds1.primitive
231
- same_name_different_primitive = same_name and not same_primitive
232
- if same_name_different_primitive:
233
- raise ValueError(
234
- f"Cannot merge datasets with different primitives for the same task name: "
235
- f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
236
- f"'{task_ds1.primitive}' in dataset1."
237
- )
238
-
239
- is_same_name_and_primitive = same_name and same_primitive
240
- if is_same_name_and_primitive:
241
- task_ds0_class_names = task_ds0.class_names or []
242
- task_ds1_class_names = task_ds1.class_names or []
243
- if task_ds0_class_names != task_ds1_class_names:
244
- raise ValueError(
245
- f"Cannot merge datasets with different class names for the same task name and primitive: "
246
- f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
247
- f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
248
- )
249
-
250
- if info1.format_version != info0.format_version:
251
- user_logger.warning(
252
- "Dataset format version of the two datasets do not match. "
253
- f"'{info1.format_version}' vs '{info0.format_version}'."
254
- )
255
- dataset_format_version = info0.format_version
256
- if hafnia.__dataset_format_version__ != dataset_format_version:
257
- user_logger.warning(
258
- f"Dataset format version '{dataset_format_version}' does not match the current "
259
- f"Hafnia format version '{hafnia.__dataset_format_version__}'."
260
- )
261
- unique_tasks = set(info0.tasks + info1.tasks)
262
- meta = (info0.meta or {}).copy()
263
- meta.update(info1.meta or {})
264
- return DatasetInfo(
265
- dataset_name=info0.dataset_name + "+" + info1.dataset_name,
266
- version=None,
267
- tasks=list(unique_tasks),
268
- meta=meta,
269
- format_version=dataset_format_version,
270
- )
271
-
272
- def get_task_by_name(self, task_name: str) -> TaskInfo:
273
- """
274
- Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
275
- """
276
- tasks_with_name = [task for task in self.tasks if task.name == task_name]
277
- if not tasks_with_name:
278
- raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
279
- if len(tasks_with_name) > 1:
280
- raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
281
- return tasks_with_name[0]
282
-
283
- def get_tasks_by_primitive(self, primitive: Union[Type[Primitive], str]) -> List[TaskInfo]:
284
- """
285
- Get all tasks by their primitive type.
286
- """
287
- if isinstance(primitive, str):
288
- primitive = get_primitive_type_from_string(primitive)
289
-
290
- tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
291
- return tasks_with_primitive
292
-
293
- def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
294
- """
295
- Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
296
- have the same primitive type.
297
- """
298
-
299
- tasks_with_primitive = self.get_tasks_by_primitive(primitive)
300
- if len(tasks_with_primitive) == 0:
301
- raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
302
- if len(tasks_with_primitive) > 1:
303
- raise ValueError(
304
- f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
305
- )
306
- return tasks_with_primitive[0]
307
-
308
- def get_task_by_task_name_and_primitive(
309
- self,
310
- task_name: Optional[str],
311
- primitive: Optional[Union[Type[Primitive], str]],
312
- ) -> TaskInfo:
313
- """
314
- Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
315
- If both 'task_name' and 'primitive' are None, the dataset must have only one task.
316
- """
317
- task = dataset_transformations.get_task_info_from_task_name_and_primitive(
318
- tasks=self.tasks,
319
- primitive=primitive,
320
- task_name=task_name,
321
- )
322
- return task
323
-
324
- def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
325
- dataset_info = self.model_copy(deep=True)
326
- has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
327
- if not has_task:
328
- raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
329
-
330
- new_tasks = []
331
- for task in dataset_info.tasks:
332
- if task.name == old_task.name and task.primitive == old_task.primitive:
333
- if new_task is None:
334
- continue # Remove the task
335
- new_tasks.append(new_task)
336
- else:
337
- new_tasks.append(task)
338
-
339
- dataset_info.tasks = new_tasks
340
- return dataset_info
341
-
342
-
343
- class Sample(BaseModel):
344
- file_path: Optional[str] = Field(description="Path to the image/video file.")
345
- height: int = Field(description="Height of the image")
346
- width: int = Field(description="Width of the image")
347
- split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
348
- tags: List[str] = Field(
349
- default_factory=list,
350
- description="Tags for a given sample. Used for creating subsets of the dataset.",
351
- )
352
- storage_format: str = Field(
353
- default=StorageFormat.IMAGE,
354
- description="Storage format. Sample data is stored as image or inside a video or zip file.",
355
- )
356
- collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
357
- collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
358
- remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
359
- sample_index: Optional[int] = Field(
360
- default=None,
361
- description="Don't manually set this, it is used for indexing samples in the dataset.",
362
- )
363
- classifications: Optional[List[Classification]] = Field(
364
- default=None, description="Optional list of classifications"
365
- )
366
- bboxes: Optional[List[Bbox]] = Field(default=None, description="Optional list of bounding boxes")
367
- bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
368
- polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
369
-
370
- attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
371
- dataset_name: Optional[str] = Field(
372
- default=None,
373
- description=(
374
- "Don't manually set this, it will be automatically defined during initialization. "
375
- "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
376
- ),
377
- )
378
- meta: Optional[Dict] = Field(
379
- default=None,
380
- description="Additional metadata, e.g., camera settings, GPS data, etc.",
381
- )
382
-
383
- def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
384
- """
385
- Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.
386
- """
387
- primitive_types = primitive_types or PRIMITIVE_TYPES
388
- annotations_primitives = [
389
- getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
390
- ]
391
- annotations = more_itertools.flatten(
392
- [primitives for primitives in annotations_primitives if primitives is not None]
393
- )
394
-
395
- return list(annotations)
396
-
397
- def read_image_pillow(self) -> Image.Image:
398
- """
399
- Reads the image from the file path and returns it as a PIL Image.
400
- Raises FileNotFoundError if the image file does not exist.
401
- """
402
- if self.file_path is None:
403
- raise ValueError(f"Sample has no '{SampleField.FILE_PATH}' defined.")
404
- path_image = Path(self.file_path)
405
- if not path_image.exists():
406
- raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
407
-
408
- image = Image.open(str(path_image))
409
- return image
410
-
411
- def read_image(self) -> np.ndarray:
412
- if self.storage_format == StorageFormat.VIDEO:
413
- video = cv2.VideoCapture(str(self.file_path))
414
- if self.collection_index is None:
415
- raise ValueError("collection_index must be set for video storage format to read the correct frame.")
416
- video.set(cv2.CAP_PROP_POS_FRAMES, self.collection_index)
417
- success, image = video.read()
418
- video.release()
419
- if not success:
420
- raise ValueError(f"Could not read frame {self.collection_index} from video file {self.file_path}.")
421
- return image
422
-
423
- elif self.storage_format == StorageFormat.IMAGE:
424
- image_pil = self.read_image_pillow()
425
- image = np.array(image_pil)
426
- else:
427
- raise ValueError(f"Unsupported storage format: {self.storage_format}")
428
- return image
429
-
430
- def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
431
- from hafnia.visualizations import image_visualizations
432
-
433
- image = image or self.read_image()
434
- annotations = self.get_annotations()
435
- annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
436
- return annotations_visualized
437
-
438
-
439
- class License(BaseModel):
440
- """License information"""
441
-
442
- name: Optional[str] = Field(
443
- default=None,
444
- description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
445
- max_length=100,
446
- )
447
- name_short: Optional[str] = Field(
448
- default=None,
449
- description="License short name or abbreviation. E.g. 'CC BY 4.0'",
450
- max_length=100,
451
- )
452
- url: Optional[str] = Field(
453
- default=None,
454
- description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
455
- )
456
- description: Optional[str] = Field(
457
- default=None,
458
- description=(
459
- "License description e.g. 'You must give appropriate credit, provide a "
460
- "link to the license, and indicate if changes were made.'"
461
- ),
462
- )
463
-
464
- valid_date: Optional[datetime] = Field(
465
- default=None,
466
- description="License valid date. E.g. '2023-01-01T00:00:00Z'",
467
- )
468
-
469
- permissions: Optional[List[str]] = Field(
470
- default=None,
471
- description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
472
- )
473
- liability: Optional[str] = Field(
474
- default=None,
475
- description="License liability. Optional and not always applicable.",
476
- )
477
- location: Optional[str] = Field(
478
- default=None,
479
- description=(
480
- "License Location. E.g. Iowa state. This is essential to understand the industry and "
481
- "privacy location specific rules that applies to the data. Optional and not always applicable."
482
- ),
483
- )
484
- notes: Optional[str] = Field(
485
- default=None,
486
- description="Additional license notes. Optional and not always applicable.",
487
- )
488
-
489
-
490
- class Attribution(BaseModel):
491
- """Attribution information for the image: Giving source and credit to the original creator"""
492
-
493
- title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
494
- creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
495
- creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
496
- date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
497
- copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
498
- licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
499
- disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
500
- changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
501
- source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
40
+ from hafnia.utils import progress_bar
502
41
 
503
42
 
504
43
  @dataclass
@@ -527,8 +66,10 @@ class HafniaDataset:
527
66
  convert_to_image_storage_format = dataset_transformations.convert_to_image_storage_format
528
67
 
529
68
  # Import / export functions
530
- from_yolo_format = format_yolo.from_yolo_format
531
69
  to_yolo_format = format_yolo.to_yolo_format
70
+ from_yolo_format = format_yolo.from_yolo_format
71
+ to_coco_format = format_coco.to_coco_format
72
+ from_coco_format = format_coco.from_coco_format
532
73
  to_image_classification_folder = format_image_classification_folder.to_image_classification_folder
533
74
  from_image_classification_folder = format_image_classification_folder.from_image_classification_folder
534
75
 
@@ -978,6 +519,10 @@ class HafniaDataset:
978
519
  dataset.check_dataset_tasks()
979
520
  return dataset
980
521
 
522
+ def has_primitive(dataset: HafniaDataset, PrimitiveType: Type[Primitive]) -> bool:
523
+ table = dataset.samples if isinstance(dataset, HafniaDataset) else dataset
524
+ return table_transformations.has_primitive(table, PrimitiveType)
525
+
981
526
  @staticmethod
982
527
  def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
983
528
  """
@@ -1026,7 +571,7 @@ class HafniaDataset:
1026
571
  hafnia_dataset = self.copy() # To avoid inplace modifications
1027
572
  new_paths = []
1028
573
  org_paths = hafnia_dataset.samples[SampleField.FILE_PATH].to_list()
1029
- for org_path in track(org_paths, description="- Copy images"):
574
+ for org_path in progress_bar(org_paths, description="- Copy images"):
1030
575
  new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
1031
576
  path_source=Path(org_path),
1032
577
  path_dataset_root=path_folder,
@@ -1145,4 +690,13 @@ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tu
1145
690
  if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
1146
691
  samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
1147
692
 
693
+ if format_version_of_dataset <= Version("0.2.0"):
694
+ if SampleField.BITMASKS in samples.columns and samples[SampleField.BITMASKS].dtype == pl.List(pl.Struct):
695
+ struct_schema = samples.schema[SampleField.BITMASKS].inner
696
+ struct_names = [f.name for f in struct_schema.fields]
697
+ if "rleString" in struct_names:
698
+ struct_names[struct_names.index("rleString")] = "rle_string"
699
+ samples = samples.with_columns(
700
+ pl.col(SampleField.BITMASKS).list.eval(pl.element().struct.rename_fields(struct_names))
701
+ )
1148
702
  return samples, dataset_info