hafnia 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. cli/__main__.py +16 -3
  2. cli/config.py +45 -4
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/keychain.py +88 -0
  8. cli/profile_cmds.py +10 -6
  9. cli/runc_cmds.py +5 -5
  10. cli/trainer_package_cmds.py +65 -0
  11. hafnia/__init__.py +2 -0
  12. hafnia/data/factory.py +1 -2
  13. hafnia/dataset/dataset_helpers.py +9 -14
  14. hafnia/dataset/dataset_names.py +10 -5
  15. hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
  16. hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
  17. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  18. hafnia/dataset/dataset_upload_helper.py +265 -56
  19. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  20. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  21. hafnia/dataset/hafnia_dataset.py +577 -213
  22. hafnia/dataset/license_types.py +63 -0
  23. hafnia/dataset/operations/dataset_stats.py +259 -3
  24. hafnia/dataset/operations/dataset_transformations.py +332 -7
  25. hafnia/dataset/operations/table_transformations.py +43 -5
  26. hafnia/dataset/primitives/__init__.py +8 -0
  27. hafnia/dataset/primitives/bbox.py +25 -12
  28. hafnia/dataset/primitives/bitmask.py +26 -14
  29. hafnia/dataset/primitives/classification.py +16 -8
  30. hafnia/dataset/primitives/point.py +7 -3
  31. hafnia/dataset/primitives/polygon.py +16 -9
  32. hafnia/dataset/primitives/segmentation.py +10 -7
  33. hafnia/experiment/hafnia_logger.py +111 -8
  34. hafnia/http.py +16 -2
  35. hafnia/platform/__init__.py +9 -3
  36. hafnia/platform/builder.py +12 -10
  37. hafnia/platform/dataset_recipe.py +104 -0
  38. hafnia/platform/datasets.py +47 -9
  39. hafnia/platform/download.py +25 -19
  40. hafnia/platform/experiment.py +51 -56
  41. hafnia/platform/trainer_package.py +57 -0
  42. hafnia/utils.py +81 -13
  43. hafnia/visualizations/image_visualizations.py +4 -4
  44. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
  45. hafnia-0.4.0.dist-info/RECORD +56 -0
  46. cli/recipe_cmds.py +0 -45
  47. hafnia-0.2.4.dist-info/RECORD +0 -49
  48. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  49. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  50. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
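For orientation, a rough usage sketch of the reworked dataset API that the diff below introduces. This is not part of the package; the calls mirror the signatures visible in the diff (HafniaDataset.from_name, from_name_public_dataset, the stats helpers mapped onto the class, create_split_dataset), and the dataset name "mnist" is only an assumed example.

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    # Load a dataset registered on the Hafnia platform (signature shown in the diff below).
    dataset = HafniaDataset.from_name("mnist")

    # New in 0.4.0: convert a public torchvision dataset, assuming "mnist" is among the
    # names returned by torchvision_to_hafnia_converters().
    tiny = HafniaDataset.from_name_public_dataset("mnist", n_samples=100)

    dataset.print_stats()  # stats/print helpers are now mapped onto the class
    train_split = dataset.create_split_dataset("train")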
@@ -1,21 +1,24 @@
  from __future__ import annotations

+ import collections
+ import copy
+ import json
  import shutil
  from dataclasses import dataclass
+ from datetime import datetime
  from pathlib import Path
  from random import Random
- from typing import Any, Dict, List, Optional, Type, Union
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union

  import more_itertools
  import numpy as np
  import polars as pl
- import rich
+ from packaging.version import Version
  from PIL import Image
- from pydantic import BaseModel, field_serializer, field_validator
- from rich import print as rprint
- from rich.table import Table
- from tqdm import tqdm
+ from pydantic import BaseModel, Field, field_serializer, field_validator
+ from rich.progress import track

+ import hafnia
  from hafnia.dataset import dataset_helpers
  from hafnia.dataset.dataset_names import (
      DATASET_FILENAMES_REQUIRED,
@@ -23,20 +26,20 @@ from hafnia.dataset.dataset_names import (
      FILENAME_ANNOTATIONS_PARQUET,
      FILENAME_DATASET_INFO,
      FILENAME_RECIPE_JSON,
+     TAG_IS_SAMPLE,
      ColumnName,
-     FieldName,
      SplitName,
  )
- from hafnia.dataset.operations import dataset_stats, dataset_transformations
+ from hafnia.dataset.operations import (
+     dataset_stats,
+     dataset_transformations,
+     table_transformations,
+ )
  from hafnia.dataset.operations.table_transformations import (
      check_image_paths,
-     create_primitive_table,
-     read_table_from_path,
- )
- from hafnia.dataset.primitives import (
-     PRIMITIVE_NAME_TO_TYPE,
-     PRIMITIVE_TYPES,
+     read_samples_from_path,
  )
+ from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
  from hafnia.dataset.primitives.bbox import Bbox
  from hafnia.dataset.primitives.bitmask import Bitmask
  from hafnia.dataset.primitives.classification import Classification
@@ -46,10 +49,16 @@ from hafnia.log import user_logger


  class TaskInfo(BaseModel):
-     primitive: Type[Primitive] # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
-     class_names: Optional[List[str]] # Class names for the tasks. To get consistent class indices specify class_names.
-     name: Optional[str] = (
-         None # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
+     primitive: Type[Primitive] = Field(
+         description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+     )
+     class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+     name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Optional name for the task. 'None' will use default name of the provided primitive. "
+             "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+         ),
      )

      def model_post_init(self, __context: Any) -> None:
@@ -64,12 +73,7 @@ class TaskInfo(BaseModel):
      @classmethod
      def ensure_primitive(cls, primitive: Any) -> Any:
          if isinstance(primitive, str):
-             if primitive not in PRIMITIVE_NAME_TO_TYPE:
-                 raise ValueError(
-                     f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
-                 )
-
-             return PRIMITIVE_NAME_TO_TYPE[primitive]
+             return get_primitive_type_from_string(primitive)

          if issubclass(primitive, Primitive):
              return primitive
@@ -83,40 +87,273 @@ class TaskInfo(BaseModel):
              raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
          return primitive.__name__

+     @field_validator("class_names", mode="after")
+     @classmethod
+     def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
+         """Validate that class names are unique"""
+         if class_names is None:
+             return None
+         duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
+         if duplicate_class_names:
+             raise ValueError(
+                 f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
+             )
+         return class_names
+
+     # To get unique hash value for TaskInfo objects
+     def __hash__(self) -> int:
+         class_names = self.class_names or []
+         return hash((self.name, self.primitive.__name__, tuple(class_names)))
+
+     def __eq__(self, other: Any) -> bool:
+         if not isinstance(other, TaskInfo):
+             return False
+         return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
+

  class DatasetInfo(BaseModel):
-     dataset_name: str
-     version: str
-     tasks: list[TaskInfo]
-     distributions: Optional[List[TaskInfo]] = None # Distributions. TODO: FIX/REMOVE/CHANGE this
-     meta: Optional[Dict[str, Any]] = None # Metadata about the dataset, e.g. description, etc.
+     dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+     version: Optional[str] = Field(default=None, description="Version of the dataset")
+     tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+     distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
+     reference_bibtex: Optional[str] = Field(
+         default=None,
+         description="Optional, BibTeX reference to dataset publication",
+     )
+     reference_paper_url: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to dataset publication",
+     )
+     reference_dataset_page: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to the dataset page",
+     )
+     meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+     format_version: str = Field(
+         default=hafnia.__dataset_format_version__,
+         description="Version of the Hafnia dataset format. You should not set this manually.",
+     )
+     updated_at: datetime = Field(
+         default_factory=datetime.now,
+         description="Timestamp of the last update to the dataset info. You should not set this manually.",
+     )
+
+     @field_validator("tasks", mode="after")
+     @classmethod
+     def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+         if tasks is None:
+             return []
+         task_name_counts = collections.Counter(task.name for task in tasks)
+         duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
+         if duplicate_task_names:
+             raise ValueError(
+                 f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
+             )
+         return tasks
+
+     @field_validator("format_version")
+     @classmethod
+     def _validate_format_version(cls, format_version: str) -> str:
+         try:
+             Version(format_version)
+         except Exception as e:
+             raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+         if Version(format_version) > Version(hafnia.__dataset_format_version__):
+             user_logger.warning(
+                 f"The loaded dataset format version '{format_version}' is newer than the format version "
+                 f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                 f"updating Hafnia package."
+             )
+         return format_version
+
+     @field_validator("version")
+     @classmethod
+     def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+         if dataset_version is None:
+             return None
+
+         try:
+             Version(dataset_version)
+         except Exception as e:
+             raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+         return dataset_version
+
+     def check_for_duplicate_task_names(self) -> List[TaskInfo]:
+         return self._validate_check_for_duplicate_tasks(self.tasks)

      def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
          json_str = self.model_dump_json(indent=indent)
          path.write_text(json_str)

      @staticmethod
-     def from_json_file(path: Path) -> "DatasetInfo":
+     def from_json_file(path: Path) -> DatasetInfo:
          json_str = path.read_text()
-         return DatasetInfo.model_validate_json(json_str)
+
+         # TODO: Deprecated support for old dataset info without format_version
+         # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
+         # when all datasets include a 'format_version' field
+         json_dict = json.loads(json_str)
+         if "format_version" not in json_dict:
+             json_dict["format_version"] = "0.0.0"
+
+         if "updated_at" not in json_dict:
+             json_dict["updated_at"] = datetime.min.isoformat()
+         dataset_info = DatasetInfo.model_validate(json_dict)
+
+         return dataset_info
+
+     @staticmethod
+     def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
+         """
+         Merges two DatasetInfo objects into one and validates if they are compatible.
+         """
+         for task_ds0 in info0.tasks:
+             for task_ds1 in info1.tasks:
+                 same_name = task_ds0.name == task_ds1.name
+                 same_primitive = task_ds0.primitive == task_ds1.primitive
+                 same_name_different_primitive = same_name and not same_primitive
+                 if same_name_different_primitive:
+                     raise ValueError(
+                         f"Cannot merge datasets with different primitives for the same task name: "
+                         f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
+                         f"'{task_ds1.primitive}' in dataset1."
+                     )
+
+                 is_same_name_and_primitive = same_name and same_primitive
+                 if is_same_name_and_primitive:
+                     task_ds0_class_names = task_ds0.class_names or []
+                     task_ds1_class_names = task_ds1.class_names or []
+                     if task_ds0_class_names != task_ds1_class_names:
+                         raise ValueError(
+                             f"Cannot merge datasets with different class names for the same task name and primitive: "
+                             f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
+                             f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
+                         )
+
+         if info1.format_version != info0.format_version:
+             user_logger.warning(
+                 "Dataset format version of the two datasets do not match. "
+                 f"'{info1.format_version}' vs '{info0.format_version}'."
+             )
+         dataset_format_version = info0.format_version
+         if hafnia.__dataset_format_version__ != dataset_format_version:
+             user_logger.warning(
+                 f"Dataset format version '{dataset_format_version}' does not match the current "
+                 f"Hafnia format version '{hafnia.__dataset_format_version__}'."
+             )
+         unique_tasks = set(info0.tasks + info1.tasks)
+         distributions = set((info0.distributions or []) + (info1.distributions or []))
+         meta = (info0.meta or {}).copy()
+         meta.update(info1.meta or {})
+         return DatasetInfo(
+             dataset_name=info0.dataset_name + "+" + info1.dataset_name,
+             version=None,
+             tasks=list(unique_tasks),
+             distributions=list(distributions),
+             meta=meta,
+             format_version=dataset_format_version,
+         )
+
+     def get_task_by_name(self, task_name: str) -> TaskInfo:
+         """
+         Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
+         """
+         tasks_with_name = [task for task in self.tasks if task.name == task_name]
+         if not tasks_with_name:
+             raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
+         if len(tasks_with_name) > 1:
+             raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
+         return tasks_with_name[0]
+
+     def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+         """
+         Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+         have the same primitive type.
+         """
+         if isinstance(primitive, str):
+             primitive = get_primitive_type_from_string(primitive)
+
+         tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
+         if not tasks_with_primitive:
+             raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
+         if len(tasks_with_primitive) > 1:
+             raise ValueError(
+                 f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
+             )
+         return tasks_with_primitive[0]
+
+     def get_task_by_task_name_and_primitive(
+         self,
+         task_name: Optional[str],
+         primitive: Optional[Union[Type[Primitive], str]],
+     ) -> TaskInfo:
+         """
+         Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
+         If both 'task_name' and 'primitive' are None, the dataset must have only one task.
+         """
+         task = dataset_transformations.get_task_info_from_task_name_and_primitive(
+             tasks=self.tasks,
+             primitive=primitive,
+             task_name=task_name,
+         )
+         return task
+
+     def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
+         dataset_info = self.model_copy(deep=True)
+         has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
+         if not has_task:
+             raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
+
+         new_tasks = []
+         for task in dataset_info.tasks:
+             if task.name == old_task.name and task.primitive == old_task.primitive:
+                 if new_task is None:
+                     continue # Remove the task
+                 new_tasks.append(new_task)
+             else:
+                 new_tasks.append(task)
+
+         dataset_info.tasks = new_tasks
+         return dataset_info


  class Sample(BaseModel):
-     file_name: str
-     height: int
-     width: int
-     split: str # Split name, e.g., "train", "val", "test"
-     is_sample: bool # Indicates if this is a sample (True) or a metadata entry (False)
-     collection_index: Optional[int] = None # Optional e.g. frame number for video datasets
-     collection_id: Optional[str] = None # Optional e.g. video name for video datasets
-     remote_path: Optional[str] = None # Optional remote path for the image, if applicable
-     sample_index: Optional[int] = None # Don't manually set this, it is used for indexing samples in the dataset.
-     classifications: Optional[List[Classification]] = None # Optional classification primitive
-     objects: Optional[List[Bbox]] = None # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
-     bitmasks: Optional[List[Bitmask]] = None # List of bitmasks, if applicable
-     polygons: Optional[List[Polygon]] = None # List of polygons, if applicable
-
-     meta: Optional[Dict] = None # Additional metadata, e.g., camera settings, GPS data, etc.
+     file_path: str = Field(description="Path to the image file")
+     height: int = Field(description="Height of the image")
+     width: int = Field(description="Width of the image")
+     split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+     tags: List[str] = Field(
+         default_factory=list,
+         description="Tags for a given sample. Used for creating subsets of the dataset.",
+     )
+     collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+     collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+     remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+     sample_index: Optional[int] = Field(
+         default=None,
+         description="Don't manually set this, it is used for indexing samples in the dataset.",
+     )
+     classifications: Optional[List[Classification]] = Field(
+         default=None, description="Optional list of classifications"
+     )
+     objects: Optional[List[Bbox]] = Field(default=None, description="Optional list of objects (bounding boxes)")
+     bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+     polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+     attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+     dataset_name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Don't manually set this, it will be automatically defined during initialization. "
+             "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+         ),
+     )
+     meta: Optional[Dict] = Field(
+         default=None,
+         description="Additional metadata, e.g., camera settings, GPS data, etc.",
+     )

      def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
          """
@@ -137,7 +374,7 @@ class Sample(BaseModel):
          Reads the image from the file path and returns it as a PIL Image.
          Raises FileNotFoundError if the image file does not exist.
          """
-         path_image = Path(self.file_name)
+         path_image = Path(self.file_path)
          if not path_image.exists():
              raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

@@ -158,11 +395,93 @@ class Sample(BaseModel):
          return annotations_visualized


+ class License(BaseModel):
+     """License information"""
+
+     name: Optional[str] = Field(
+         default=None,
+         description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
+         max_length=100,
+     )
+     name_short: Optional[str] = Field(
+         default=None,
+         description="License short name or abbreviation. E.g. 'CC BY 4.0'",
+         max_length=100,
+     )
+     url: Optional[str] = Field(
+         default=None,
+         description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
+     )
+     description: Optional[str] = Field(
+         default=None,
+         description=(
+             "License description e.g. 'You must give appropriate credit, provide a "
+             "link to the license, and indicate if changes were made.'"
+         ),
+     )
+
+     valid_date: Optional[datetime] = Field(
+         default=None,
+         description="License valid date. E.g. '2023-01-01T00:00:00Z'",
+     )
+
+     permissions: Optional[List[str]] = Field(
+         default=None,
+         description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
+     )
+     liability: Optional[str] = Field(
+         default=None,
+         description="License liability. Optional and not always applicable.",
+     )
+     location: Optional[str] = Field(
+         default=None,
+         description=(
+             "License Location. E.g. Iowa state. This is essential to understand the industry and "
+             "privacy location specific rules that applies to the data. Optional and not always applicable."
+         ),
+     )
+     notes: Optional[str] = Field(
+         default=None,
+         description="Additional license notes. Optional and not always applicable.",
+     )
+
+
+ class Attribution(BaseModel):
+     """Attribution information for the image: Giving source and credit to the original creator"""
+
+     title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
+     creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
+     creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
+     date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
+     copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
+     licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
+     disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
+     changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
+     source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+
+
  @dataclass
  class HafniaDataset:
      info: DatasetInfo
      samples: pl.DataFrame

+     # Function mapping: Dataset stats
+     split_counts = dataset_stats.split_counts
+     class_counts_for_task = dataset_stats.class_counts_for_task
+     class_counts_all = dataset_stats.class_counts_all
+
+     # Function mapping: Print stats
+     print_stats = dataset_stats.print_stats
+     print_sample_and_task_counts = dataset_stats.print_sample_and_task_counts
+     print_class_distribution = dataset_stats.print_class_distribution
+
+     # Function mapping: Dataset checks
+     check_dataset = dataset_stats.check_dataset
+     check_dataset_tasks = dataset_stats.check_dataset_tasks
+
+     # Function mapping: Dataset transformations
+     transform_images = dataset_transformations.transform_images
+
      def __getitem__(self, item: int) -> Dict[str, Any]:
          return self.samples.row(index=item, named=True)

@@ -173,30 +492,36 @@ class HafniaDataset:
          for row in self.samples.iter_rows(named=True):
              yield row

+     def __post_init__(self):
+         self.samples, self.info = _dataset_corrections(self.samples, self.info)
+
      @staticmethod
      def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+         path_folder = Path(path_folder)
          HafniaDataset.check_dataset_path(path_folder, raise_error=True)

          dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-         table = read_table_from_path(path_folder)
+         samples = read_samples_from_path(path_folder)
+         samples, dataset_info = _dataset_corrections(samples, dataset_info)

          # Convert from relative paths to absolute paths
          dataset_root = path_folder.absolute().as_posix() + "/"
-         table = table.with_columns((dataset_root + pl.col("file_name")).alias("file_name"))
+         samples = samples.with_columns((dataset_root + pl.col(ColumnName.FILE_PATH)).alias(ColumnName.FILE_PATH))
          if check_for_images:
-             check_image_paths(table)
-         return HafniaDataset(samples=table, info=dataset_info)
+             check_image_paths(samples)
+         return HafniaDataset(samples=samples, info=dataset_info)

      @staticmethod
      def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
          """
          Load a dataset by its name. The dataset must be registered in the Hafnia platform.
          """
-         from hafnia.dataset.hafnia_dataset import HafniaDataset
          from hafnia.platform.datasets import download_or_get_dataset_path

          dataset_path = download_or_get_dataset_path(
-             dataset_name=name, force_redownload=force_redownload, download_files=download_files
+             dataset_name=name,
+             force_redownload=force_redownload,
+             download_files=download_files,
          )
          return HafniaDataset.from_path(dataset_path, check_for_images=download_files)

@@ -210,9 +535,16 @@ class HafniaDataset:
              else:
                  raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

-         table = pl.from_records(json_samples).drop(ColumnName.SAMPLE_INDEX)
-         table = table.with_row_index(name=ColumnName.SAMPLE_INDEX) # Add sample index column
+         table = pl.from_records(json_samples)
+         table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)

+         # Add 'dataset_name' to samples
+         table = table.with_columns(
+             pl.when(pl.col(ColumnName.DATASET_NAME).is_null())
+             .then(pl.lit(info.dataset_name))
+             .otherwise(pl.col(ColumnName.DATASET_NAME))
+             .alias(ColumnName.DATASET_NAME)
+         )
          return HafniaDataset(info=info, samples=table)

      @staticmethod
@@ -241,7 +573,11 @@ class HafniaDataset:
          If the dataset is already cached, it will be loaded from the cache.
          """

-         path_dataset = get_or_create_dataset_path_from_recipe(dataset_recipe, path_datasets=path_datasets)
+         path_dataset = get_or_create_dataset_path_from_recipe(
+             dataset_recipe,
+             path_datasets=path_datasets,
+             force_redownload=force_redownload,
+         )
          return HafniaDataset.from_path(path_dataset, check_for_images=False)

      @staticmethod
@@ -263,20 +599,46 @@ class HafniaDataset:
              merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
          return merged_dataset

-     # Dataset transformations
-     transform_images = dataset_transformations.transform_images
+     @staticmethod
+     def from_name_public_dataset(
+         name: str,
+         force_redownload: bool = False,
+         n_samples: Optional[int] = None,
+     ) -> HafniaDataset:
+         from hafnia.dataset.format_conversions.torchvision_datasets import (
+             torchvision_to_hafnia_converters,
+         )
+
+         name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+         if name not in name_to_torchvision_function:
+             raise ValueError(
+                 f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+             )
+         vision_dataset = name_to_torchvision_function[name]
+         return vision_dataset(
+             force_redownload=force_redownload,
+             n_samples=n_samples,
+         )

      def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
          table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
-         return dataset.update_table(table)
+         return dataset.update_samples(table)

      def select_samples(
-         dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+         dataset: "HafniaDataset",
+         n_samples: int,
+         shuffle: bool = True,
+         seed: int = 42,
+         with_replacement: bool = False,
      ) -> "HafniaDataset":
+         """
+         Create a new dataset with a subset of samples.
+         """
          if not with_replacement:
              n_samples = min(n_samples, len(dataset))
          table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
-         return dataset.update_table(table)
+         return dataset.update_samples(table)

      def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
          """
@@ -295,7 +657,7 @@ class HafniaDataset:
              split_ratios=split_ratios, n_items=n_items, seed=seed
          )
          table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
-         return dataset.update_table(table)
+         return dataset.update_samples(table)

      def split_into_multiple_splits(
          dataset: "HafniaDataset",
@@ -323,33 +685,124 @@ class HafniaDataset:

          remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
          new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
-         dataset_new = dataset.update_table(new_table)
+         dataset_new = dataset.update_samples(new_table)
          return dataset_new

      def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+         """
+         Defines a sample set randomly by selecting 'n_samples' samples from the dataset.
+         """
+         samples = dataset.samples
+
+         # Remove any pre-existing "sample"-tags
+         samples = samples.with_columns(
+             pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+         )
+
+         # Add "sample" to tags column for the selected samples
          is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
-         is_sample_column = [False for _ in range(len(dataset))]
-         for idx in is_sample_indices:
-             is_sample_column[idx] = True
+         samples = samples.with_columns(
+             pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
+             .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+             .otherwise(pl.col(ColumnName.TAGS))
+         )
+         return dataset.update_samples(samples)
+
+     def class_mapper(
+         dataset: "HafniaDataset",
+         class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
+         method: str = "strict",
+         primitive: Optional[Type[Primitive]] = None,
+         task_name: Optional[str] = None,
+     ) -> "HafniaDataset":
+         """
+         Map class names to new class names using a strict mapping.
+         A strict mapping means that all class names in the dataset must be mapped to a new class name.
+         If a class name is not mapped, an error is raised.
+
+         The class indices are determined by the order of appearance of the new class names in the mapping.
+         Duplicates in the new class names are removed, preserving the order of first appearance.
+
+         E.g.
+
+         mnist = HafniaDataset.from_name("mnist")
+         strict_class_mapping = {
+             "1 - one": "odd", # 'odd' appears first and becomes class index 0
+             "3 - three": "odd",
+             "5 - five": "odd",
+             "7 - seven": "odd",
+             "9 - nine": "odd",
+             "0 - zero": "even", # 'even' appears second and becomes class index 1
+             "2 - two": "even",
+             "4 - four": "even",
+             "6 - six": "even",
+             "8 - eight": "even",
+         }
+
+         dataset_new = class_mapper(dataset=mnist, class_mapping=strict_class_mapping)
+
+         """
+         return dataset_transformations.class_mapper(
+             dataset=dataset,
+             class_mapping=class_mapping,
+             method=method,
+             primitive=primitive,
+             task_name=task_name,
+         )

-         table = dataset.samples.with_columns(pl.Series(is_sample_column).alias("is_sample"))
-         return dataset.update_table(table)
+     def rename_task(
+         dataset: "HafniaDataset",
+         old_task_name: str,
+         new_task_name: str,
+     ) -> "HafniaDataset":
+         """
+         Rename a task in the dataset.
+         """
+         return dataset_transformations.rename_task(
+             dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
+         )
+
+     def select_samples_by_class_name(
+         dataset: HafniaDataset,
+         name: Union[List[str], str],
+         task_name: Optional[str] = None,
+         primitive: Optional[Type[Primitive]] = None,
+     ) -> HafniaDataset:
+         """
+         Select samples that contain at least one annotation with the specified class name(s).
+         If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+         """
+         return dataset_transformations.select_samples_by_class_name(
+             dataset=dataset, name=name, task_name=task_name, primitive=primitive
+         )

      def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
          """
          Merges two Hafnia datasets by concatenating their samples and updating the split names.
          """
-         ## Currently, only a very naive merging is implemented.
-         # In the future we need to verify that the class and tasks are compatible.
-         # Do they have similar classes and tasks? What to do if they don't?
-         # For now, we just concatenate the samples and keep the split names as they are.
-         merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
-         return dataset0.update_table(merged_samples)

-     # Dataset stats
-     split_counts = dataset_stats.split_counts
+         # Merges dataset info and checks for compatibility
+         merged_info = DatasetInfo.merge(dataset0.info, dataset1.info)
+
+         # Merges samples tables (removes incompatible columns)
+         merged_samples = table_transformations.merge_samples(samples0=dataset0.samples, samples1=dataset1.samples)
+
+         # Check if primitives have been removed during the merge_samples
+         for task in copy.deepcopy(merged_info.tasks):
+             if task.primitive.column_name() not in merged_samples.columns:
+                 user_logger.warning(
+                     f"Task '{task.name}' with primitive '{task.primitive.__name__}' has been removed during the merge. "
+                     "This happens if the two datasets do not have the same primitives."
+                 )
+                 merged_info = merged_info.replace_task(old_task=task, new_task=None)
+
+         return HafniaDataset(info=merged_info, samples=merged_samples)

      def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+         """
+         Splits the dataset into multiple datasets based on the 'split' column.
+         Returns a dictionary with split names as keys and HafniaDataset objects as values.
+         """
          if ColumnName.SPLIT not in self.samples.columns:
              raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")

@@ -360,10 +813,22 @@ class HafniaDataset:
          return splits

      def create_sample_dataset(self) -> "HafniaDataset":
-         if ColumnName.IS_SAMPLE not in self.samples.columns:
-             raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
-         table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
-         return self.update_table(table)
+         # Backwards compatibility. Remove in future versions when dataset have been updated
+         if "is_sample" in self.samples.columns:
+             user_logger.warning(
+                 "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
+                 "Please use the 'tags' column with the tag 'sample' instead."
+             )
+             table = self.samples.filter(pl.col("is_sample") == True) # noqa: E712
+             return self.update_samples(table)
+
+         if ColumnName.TAGS not in self.samples.columns:
+             raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+
+         table = self.samples.filter(
+             pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+         )
+         return self.update_samples(table)

      def create_split_dataset(self, split_name: Union[str | List[str]]) -> "HafniaDataset":
          if isinstance(split_name, str):
@@ -376,16 +841,12 @@ class HafniaDataset:
              raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

          filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
-         return self.update_table(filtered_dataset)
-
-     def get_task_by_name(self, task_name: str) -> TaskInfo:
-         for task in self.info.tasks:
-             if task.name == task_name:
-                 return task
-         raise ValueError(f"Task with name {task_name} not found in dataset info.")
+         return self.update_samples(filtered_dataset)

-     def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
-         return HafniaDataset(info=self.info.model_copy(), samples=table)
+     def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
+         dataset = HafniaDataset(info=self.info.model_copy(deep=True), samples=table)
+         dataset.check_dataset_tasks()
+         return dataset

      @staticmethod
      def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
@@ -411,19 +872,27 @@ class HafniaDataset:

          return True

-     def write(self, path_folder: Path, add_version: bool = False) -> None:
+     def copy(self) -> "HafniaDataset":
+         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
+
+     def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
          user_logger.info(f"Writing dataset to {path_folder}...")
          if not path_folder.exists():
              path_folder.mkdir(parents=True)

          new_relative_paths = []
-         for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
+         org_paths = self.samples[ColumnName.FILE_PATH].to_list()
+         for org_path in track(org_paths, description="- Copy images"):
              new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                  path_source=Path(org_path),
                  path_dataset_root=path_folder,
              )
              new_relative_paths.append(str(new_path.relative_to(path_folder)))
-         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+         table = self.samples.with_columns(pl.Series(new_relative_paths).alias(ColumnName.FILE_PATH))
+
+         if drop_null_cols: # Drops all unused/Null columns
+             table = table.drop(pl.selectors.by_dtype(pl.Null))
+
          table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL) # Json for readability
          table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET) # Parquet for speed
          self.info.write_json(path_folder / FILENAME_DATASET_INFO)
@@ -448,51 +917,10 @@ class HafniaDataset:
              return False
          return True

-     def print_stats(self) -> None:
-         table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-         table_base.add_column("Property", style="cyan")
-         table_base.add_column("Value")
-         table_base.add_row("Dataset Name", self.info.dataset_name)
-         table_base.add_row("Version", self.info.version)
-         table_base.add_row("Number of samples", str(len(self.samples)))
-         rprint(table_base)
-         rprint(self.info.tasks)
-
-         splits_sets = {
-             "All": SplitName.valid_splits(),
-             "Train": [SplitName.TRAIN],
-             "Validation": [SplitName.VAL],
-             "Test": [SplitName.TEST],
-         }
-         rows = []
-         for split_name, splits in splits_sets.items():
-             dataset_split = self.create_split_dataset(splits)
-             table = dataset_split.samples
-             row = {}
-             row["Split"] = split_name
-             row["Sample "] = str(len(table))
-             for PrimitiveType in PRIMITIVE_TYPES:
-                 column_name = PrimitiveType.column_name()
-                 objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-                 if objects_df is None:
-                     continue
-                 for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                     count = len(object_group[FieldName.CLASS_NAME])
-                     row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-             rows.append(row)
-
-         rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-         for i_row, row in enumerate(rows):
-             if i_row == 0:
-                 for column_name in row.keys():
-                     rich_table.add_column(column_name, justify="left", style="cyan")
-             rich_table.add_row(*[str(value) for value in row.values()])
-         rprint(rich_table)
-

  def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
      dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
-     check_hafnia_dataset(dataset)
+     dataset.check_dataset()


  def get_or_create_dataset_path_from_recipe(
@@ -524,87 +952,23 @@ def get_or_create_dataset_path_from_recipe(
      return path_dataset


- def check_hafnia_dataset(dataset: HafniaDataset):
-     user_logger.info("Checking Hafnia dataset...")
-     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
-     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
-
-     is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
-     if True not in is_sample_list:
-         raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
-
-     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-     expected_splits = SplitName.valid_splits()
-     if set(actual_splits) != set(expected_splits):
-         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
-
-     expected_tasks = dataset.info.tasks
-     for task in expected_tasks:
-         primitive = task.primitive.__name__
-         column_name = task.primitive.column_name()
-         primitive_column = dataset.samples[column_name]
-         # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
-         msg_something_wrong = (
-             f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
-             f"For '{primitive=}' and '{task.name=}' "
-         )
-         if primitive_column.dtype == pl.Null:
-             raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+     format_version_of_dataset = Version(dataset_info.format_version)

-         primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
-         if primitive_table.is_empty():
-             raise ValueError(
-                 msg_something_wrong
-                 + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
-             )
+     ## Backwards compatibility fixes for older dataset versions
+     if format_version_of_dataset <= Version("0.3.0"):
+         if ColumnName.DATASET_NAME not in samples.columns:
+             samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))

-         actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
-         if task.class_names is None:
-             raise ValueError(
-                 msg_something_wrong
-                 + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
-             )
-         defined_classes = set(task.class_names)
+         if "file_name" in samples.columns:
+             samples = samples.rename({"file_name": ColumnName.FILE_PATH})

-         if not actual_classes.issubset(defined_classes):
-             raise ValueError(
-                 msg_something_wrong
-                 + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
-                 f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
-             )
-         # Check class_indices
-         mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
-             lambda x: task.class_names.index(x), return_dtype=pl.Int64
-         )
-         table_indices = primitive_table[FieldName.CLASS_IDX]
+         if ColumnName.SAMPLE_INDEX not in samples.columns:
+             samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)

-         error_msg = msg_something_wrong + (
-             f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
-         )
-         assert mapped_indices.equals(table_indices), error_msg
-
-     distribution = dataset.info.distributions or []
-     distribution_names = [task.name for task in distribution]
-     # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
-     for PrimitiveType in PRIMITIVE_TYPES:
-         column_name = PrimitiveType.column_name()
-         if column_name not in dataset.samples.columns:
-             continue
-         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
-         if objects_df is None:
-             continue
-         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-             if has_task:
-                 continue
-             if task_name in distribution_names:
-                 continue
-             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
-             raise ValueError(
-                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                 f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
-                 f"classes: {class_names}. "
-             )
+     # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+     if ColumnName.TAGS not in samples.columns:
+         tags_column: List[List[str]] = [[] for _ in range(len(samples))] # type: ignore[annotation-unchecked]
+         samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))

-     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
-         sample = Sample(**sample_dict) # Checks format of all samples with pydantic validation # noqa: F841
+     return samples, dataset_info
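
Taken together, the hafnia_dataset.py changes amount to a small migration for 0.2.4 callers. A hedged sketch of the renames visible in the diff above (not an official migration guide); the folder path is a placeholder:

    from pathlib import Path

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_path(Path("path/to/dataset"))  # samples now use 'file_path' instead of 'file_name'

    # 0.2.4: dataset.update_table(table) -> 0.4.0: dataset.update_samples(table), which also runs check_dataset_tasks()
    subset = dataset.select_samples(n_samples=10)

    # 0.2.4: boolean 'is_sample' column -> 0.4.0: 'tags' column carrying the 'sample' tag
    sample_set = dataset.define_sample_set_by_size(n_samples=5).create_sample_dataset()

    dataset.check_dataset()  # replaces the module-level check_hafnia_dataset(dataset)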