hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +13 -2
  2. cli/config.py +2 -1
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/profile_cmds.py +6 -5
  8. cli/runc_cmds.py +5 -5
  9. cli/trainer_package_cmds.py +65 -0
  10. hafnia/__init__.py +2 -0
  11. hafnia/data/factory.py +1 -2
  12. hafnia/dataset/dataset_helpers.py +0 -12
  13. hafnia/dataset/dataset_names.py +8 -4
  14. hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
  15. hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
  16. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  17. hafnia/dataset/dataset_upload_helper.py +206 -53
  18. hafnia/dataset/hafnia_dataset.py +432 -194
  19. hafnia/dataset/license_types.py +63 -0
  20. hafnia/dataset/operations/dataset_stats.py +260 -3
  21. hafnia/dataset/operations/dataset_transformations.py +325 -4
  22. hafnia/dataset/operations/table_transformations.py +39 -2
  23. hafnia/dataset/primitives/__init__.py +8 -0
  24. hafnia/dataset/primitives/classification.py +1 -1
  25. hafnia/experiment/hafnia_logger.py +112 -0
  26. hafnia/http.py +16 -2
  27. hafnia/platform/__init__.py +9 -3
  28. hafnia/platform/builder.py +12 -10
  29. hafnia/platform/dataset_recipe.py +99 -0
  30. hafnia/platform/datasets.py +44 -6
  31. hafnia/platform/download.py +2 -1
  32. hafnia/platform/experiment.py +51 -56
  33. hafnia/platform/trainer_package.py +57 -0
  34. hafnia/utils.py +64 -13
  35. hafnia/visualizations/image_visualizations.py +3 -3
  36. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
  37. hafnia-0.3.0.dist-info/RECORD +53 -0
  38. cli/recipe_cmds.py +0 -45
  39. hafnia-0.2.4.dist-info/RECORD +0 -49
  40. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
  41. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+import collections
+import copy
+import json
 import shutil
 from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
 from random import Random
 from typing import Any, Dict, List, Optional, Type, Union
@@ -9,13 +13,11 @@ from typing import Any, Dict, List, Optional, Type, Union
 import more_itertools
 import numpy as np
 import polars as pl
-import rich
 from PIL import Image
-from pydantic import BaseModel, field_serializer, field_validator
-from rich import print as rprint
-from rich.table import Table
+from pydantic import BaseModel, Field, field_serializer, field_validator
 from tqdm import tqdm
 
+import hafnia
 from hafnia.dataset import dataset_helpers
 from hafnia.dataset.dataset_names import (
     DATASET_FILENAMES_REQUIRED,
@@ -23,20 +25,16 @@ from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_PARQUET,
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
+    TAG_IS_SAMPLE,
     ColumnName,
-    FieldName,
     SplitName,
 )
-from hafnia.dataset.operations import dataset_stats, dataset_transformations
+from hafnia.dataset.operations import dataset_stats, dataset_transformations, table_transformations
 from hafnia.dataset.operations.table_transformations import (
     check_image_paths,
-    create_primitive_table,
     read_table_from_path,
 )
-from hafnia.dataset.primitives import (
-    PRIMITIVE_NAME_TO_TYPE,
-    PRIMITIVE_TYPES,
-)
+from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
 from hafnia.dataset.primitives.bitmask import Bitmask
 from hafnia.dataset.primitives.classification import Classification
@@ -48,9 +46,7 @@ from hafnia.log import user_logger
 class TaskInfo(BaseModel):
     primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
     class_names: Optional[List[str]]  # Class names for the tasks. To get consistent class indices specify class_names.
-    name: Optional[str] = (
-        None  # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
-    )
+    name: Optional[str] = None  # Use 'None' to use default name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
 
     def model_post_init(self, __context: Any) -> None:
         if self.name is None:
@@ -64,12 +60,7 @@ class TaskInfo(BaseModel):
     @classmethod
     def ensure_primitive(cls, primitive: Any) -> Any:
         if isinstance(primitive, str):
-            if primitive not in PRIMITIVE_NAME_TO_TYPE:
-                raise ValueError(
-                    f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
-                )
-
-            return PRIMITIVE_NAME_TO_TYPE[primitive]
+            return get_primitive_type_from_string(primitive)
 
         if issubclass(primitive, Primitive):
             return primitive
@@ -83,22 +74,187 @@ class TaskInfo(BaseModel):
             raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
         return primitive.__name__
 
+    @field_validator("class_names", mode="after")
+    @classmethod
+    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
+        """Validate that class names are unique"""
+        if class_names is None:
+            return None
+        duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
+        if duplicate_class_names:
+            raise ValueError(
+                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
+            )
+        return class_names
+
+    # To get unique hash value for TaskInfo objects
+    def __hash__(self) -> int:
+        class_names = self.class_names or []
+        return hash((self.name, self.primitive.__name__, tuple(class_names)))
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, TaskInfo):
+            return False
+        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
+
 
 class DatasetInfo(BaseModel):
     dataset_name: str
-    version: str
-    tasks: list[TaskInfo]
+    version: str  # Dataset version. This is not the same as the Hafnia dataset format version.
+    tasks: List[TaskInfo]
     distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
     meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.
+    format_version: str = hafnia.__dataset_format_version__  # Version of the Hafnia dataset format
+    updated_at: datetime = datetime.now()
+
+    @field_validator("tasks", mode="after")
+    @classmethod
+    def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+        task_name_counts = collections.Counter(task.name for task in tasks)
+        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
+        if duplicate_task_names:
+            raise ValueError(
+                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
+            )
+        return tasks
+
+    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
+        return self._validate_check_for_duplicate_tasks(self.tasks)
 
     def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
         json_str = self.model_dump_json(indent=indent)
         path.write_text(json_str)
 
     @staticmethod
-    def from_json_file(path: Path) -> "DatasetInfo":
+    def from_json_file(path: Path) -> DatasetInfo:
         json_str = path.read_text()
-        return DatasetInfo.model_validate_json(json_str)
+
+        # TODO: Deprecated support for old dataset info without format_version
+        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
+        # when all datasets include a 'format_version' field
+        json_dict = json.loads(json_str)
+        if "format_version" not in json_dict:
+            json_dict["format_version"] = "0.0.0"
+
+        if "updated_at" not in json_dict:
+            json_dict["updated_at"] = datetime.min.isoformat()
+        dataset_info = DatasetInfo.model_validate(json_dict)
+
+        return dataset_info
+
+    @staticmethod
+    def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
+        """
+        Merges two DatasetInfo objects into one and validates if they are compatible.
+        """
+        for task_ds0 in info0.tasks:
+            for task_ds1 in info1.tasks:
+                same_name = task_ds0.name == task_ds1.name
+                same_primitive = task_ds0.primitive == task_ds1.primitive
+                same_name_different_primitive = same_name and not same_primitive
+                if same_name_different_primitive:
+                    raise ValueError(
+                        f"Cannot merge datasets with different primitives for the same task name: "
+                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
+                        f"'{task_ds1.primitive}' in dataset1."
+                    )
+
+                is_same_name_and_primitive = same_name and same_primitive
+                if is_same_name_and_primitive:
+                    task_ds0_class_names = task_ds0.class_names or []
+                    task_ds1_class_names = task_ds1.class_names or []
+                    if task_ds0_class_names != task_ds1_class_names:
+                        raise ValueError(
+                            f"Cannot merge datasets with different class names for the same task name and primitive: "
+                            f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
+                            f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
+                        )
+
+        if info1.format_version != info0.format_version:
+            user_logger.warning(
+                "Dataset format version of the two datasets do not match. "
+                f"'{info1.format_version}' vs '{info0.format_version}'."
+            )
+        dataset_format_version = info0.format_version
+        if hafnia.__dataset_format_version__ != dataset_format_version:
+            user_logger.warning(
+                f"Dataset format version '{dataset_format_version}' does not match the current "
+                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
+            )
+        unique_tasks = set(info0.tasks + info1.tasks)
+        distributions = set((info0.distributions or []) + (info1.distributions or []))
+        meta = (info0.meta or {}).copy()
+        meta.update(info1.meta or {})
+        return DatasetInfo(
+            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
+            version="merged",
+            tasks=list(unique_tasks),
+            distributions=list(distributions),
+            meta=meta,
+            format_version=dataset_format_version,
+        )
+
+    def get_task_by_name(self, task_name: str) -> TaskInfo:
+        """
+        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
+        """
+        tasks_with_name = [task for task in self.tasks if task.name == task_name]
+        if not tasks_with_name:
+            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
+        if len(tasks_with_name) > 1:
+            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
+        return tasks_with_name[0]
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+        if isinstance(primitive, str):
+            primitive = get_primitive_type_from_string(primitive)
+
+        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
+        if not tasks_with_primitive:
+            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
+        if len(tasks_with_primitive) > 1:
+            raise ValueError(
+                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
+            )
+        return tasks_with_primitive[0]
+
+    def get_task_by_task_name_and_primitive(
+        self,
+        task_name: Optional[str],
+        primitive: Optional[Union[Type[Primitive], str]],
+    ) -> TaskInfo:
+        """
+        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
+        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
+        """
+        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
+            tasks=self.tasks,
+            primitive=primitive,
+            task_name=task_name,
+        )
+        return task
+
+    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
+        dataset_info = self.model_copy(deep=True)
+        has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
+        if not has_task:
+            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
+
+        new_tasks = []
+        for task in dataset_info.tasks:
+            if task.name == old_task.name and task.primitive == old_task.primitive:
+                if new_task is None:
+                    continue  # Remove the task
+                new_tasks.append(new_task)
+            else:
+                new_tasks.append(task)
+
+        dataset_info.tasks = new_tasks
+        return dataset_info
 
 
 class Sample(BaseModel):
@@ -106,7 +262,7 @@ class Sample(BaseModel):
     height: int
     width: int
     split: str  # Split name, e.g., "train", "val", "test"
-    is_sample: bool  # Indicates if this is a sample (True) or a metadata entry (False)
+    tags: List[str] = []  # tags for a given sample. Used for creating subsets of the dataset.
     collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
     collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
     remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
@@ -116,6 +272,7 @@ class Sample(BaseModel):
     bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
     polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable
 
+    attribution: Optional[Attribution] = None  # Attribution information for the image
     meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.
 
     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
@@ -158,11 +315,93 @@ class Sample(BaseModel):
         return annotations_visualized
 
 
+class License(BaseModel):
+    """License information"""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
+        max_length=100,
+    )
+    name_short: Optional[str] = Field(
+        default=None,
+        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
+        max_length=100,
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description=(
+            "License description e.g. 'You must give appropriate credit, provide a "
+            "link to the license, and indicate if changes were made.'"
+        ),
+    )
+
+    valid_date: Optional[datetime] = Field(
+        default=None,
+        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
+    )
+
+    permissions: Optional[List[str]] = Field(
+        default=None,
+        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
+    )
+    liability: Optional[str] = Field(
+        default=None,
+        description="License liability. Optional and not always applicable.",
+    )
+    location: Optional[str] = Field(
+        default=None,
+        description=(
+            "License Location. E.g. Iowa state. This is essential to understand the industry and "
+            "privacy location specific rules that applies to the data. Optional and not always applicable."
+        ),
+    )
+    notes: Optional[str] = Field(
+        default=None,
+        description="Additional license notes. Optional and not always applicable.",
+    )
+
+
+class Attribution(BaseModel):
+    """Attribution information for the image: Giving source and credit to the original creator"""
+
+    title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
+    creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
+    creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
+    date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
+    copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
+    licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
+    disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
+    changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
+    source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+
+
 @dataclass
 class HafniaDataset:
     info: DatasetInfo
     samples: pl.DataFrame
 
+    # Function mapping: Dataset stats
+    split_counts = dataset_stats.split_counts
+    class_counts_for_task = dataset_stats.class_counts_for_task
+    class_counts_all = dataset_stats.class_counts_all
+
+    # Function mapping: Print stats
+    print_stats = dataset_stats.print_stats
+    print_sample_and_task_counts = dataset_stats.print_sample_and_task_counts
+    print_class_distribution = dataset_stats.print_class_distribution
+
+    # Function mapping: Dataset checks
+    check_dataset = dataset_stats.check_dataset
+    check_dataset_tasks = dataset_stats.check_dataset_tasks
+
+    # Function mapping: Dataset transformations
+    transform_images = dataset_transformations.transform_images
+
     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
 
@@ -173,6 +412,18 @@ class HafniaDataset:
         for row in self.samples.iter_rows(named=True):
             yield row
 
+    def __post_init__(self):
+        samples = self.samples
+        if ColumnName.SAMPLE_INDEX not in samples.columns:
+            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+        # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+        if ColumnName.TAGS not in samples.columns:
+            tags_column: List[List[str]] = [[] for _ in range(len(self))]  # type: ignore[annotation-unchecked]
+            samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
+
+        self.samples = samples
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)
@@ -192,11 +443,12 @@ class HafniaDataset:
         """
         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
         """
-        from hafnia.dataset.hafnia_dataset import HafniaDataset
         from hafnia.platform.datasets import download_or_get_dataset_path
 
         dataset_path = download_or_get_dataset_path(
-            dataset_name=name, force_redownload=force_redownload, download_files=download_files
+            dataset_name=name,
+            force_redownload=force_redownload,
+            download_files=download_files,
         )
         return HafniaDataset.from_path(dataset_path, check_for_images=download_files)
 
@@ -210,9 +462,8 @@ class HafniaDataset:
         else:
             raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")
 
-        table = pl.from_records(json_samples).drop(ColumnName.SAMPLE_INDEX)
-        table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)  # Add sample index column
-
+        table = pl.from_records(json_samples)
+        table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
         return HafniaDataset(info=info, samples=table)
 
     @staticmethod
@@ -241,7 +492,11 @@ class HafniaDataset:
         If the dataset is already cached, it will be loaded from the cache.
         """
 
-        path_dataset = get_or_create_dataset_path_from_recipe(dataset_recipe, path_datasets=path_datasets)
+        path_dataset = get_or_create_dataset_path_from_recipe(
+            dataset_recipe,
+            path_datasets=path_datasets,
+            force_redownload=force_redownload,
+        )
         return HafniaDataset.from_path(path_dataset, check_for_images=False)
 
     @staticmethod
@@ -263,20 +518,24 @@ class HafniaDataset:
             merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
         return merged_dataset
 
-    # Dataset transformations
-    transform_images = dataset_transformations.transform_images
-
     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
-        return dataset.update_table(table)
+        return dataset.update_samples(table)
 
     def select_samples(
-        dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+        dataset: "HafniaDataset",
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> "HafniaDataset":
+        """
+        Create a new dataset with a subset of samples.
+        """
         if not with_replacement:
             n_samples = min(n_samples, len(dataset))
         table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
-        return dataset.update_table(table)
+        return dataset.update_samples(table)
 
     def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
         """
@@ -295,7 +554,7 @@ class HafniaDataset:
             split_ratios=split_ratios, n_items=n_items, seed=seed
         )
         table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
-        return dataset.update_table(table)
+        return dataset.update_samples(table)
 
     def split_into_multiple_splits(
         dataset: "HafniaDataset",
@@ -323,33 +582,124 @@ class HafniaDataset:
 
         remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
-        dataset_new = dataset.update_table(new_table)
+        dataset_new = dataset.update_samples(new_table)
         return dataset_new
 
     def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+        """
+        Defines a sample set randomly by selecting 'n_samples' samples from the dataset.
+        """
+        samples = dataset.samples
+
+        # Remove any pre-existing "sample"-tags
+        samples = samples.with_columns(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+        )
+
+        # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
-        is_sample_column = [False for _ in range(len(dataset))]
-        for idx in is_sample_indices:
-            is_sample_column[idx] = True
+        samples = samples.with_columns(
+            pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
+            .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(ColumnName.TAGS))
+        )
+        return dataset.update_samples(samples)
+
+    def class_mapper(
+        dataset: "HafniaDataset",
+        class_mapping: Dict[str, str],
+        method: str = "strict",
+        primitive: Optional[Type[Primitive]] = None,
+        task_name: Optional[str] = None,
+    ) -> "HafniaDataset":
+        """
+        Map class names to new class names using a strict mapping.
+        A strict mapping means that all class names in the dataset must be mapped to a new class name.
+        If a class name is not mapped, an error is raised.
+
+        The class indices are determined by the order of appearance of the new class names in the mapping.
+        Duplicates in the new class names are removed, preserving the order of first appearance.
+
+        E.g.
+
+        mnist = HafniaDataset.from_name("mnist")
+        strict_class_mapping = {
+            "1 - one": "odd",  # 'odd' appears first and becomes class index 0
+            "3 - three": "odd",
+            "5 - five": "odd",
+            "7 - seven": "odd",
+            "9 - nine": "odd",
+            "0 - zero": "even",  # 'even' appears second and becomes class index 1
+            "2 - two": "even",
+            "4 - four": "even",
+            "6 - six": "even",
+            "8 - eight": "even",
+        }
+
+        dataset_new = class_mapper(dataset=mnist, class_mapping=strict_class_mapping)
 
-        table = dataset.samples.with_columns(pl.Series(is_sample_column).alias("is_sample"))
-        return dataset.update_table(table)
+        """
+        return dataset_transformations.class_mapper(
+            dataset=dataset,
+            class_mapping=class_mapping,
+            method=method,
+            primitive=primitive,
+            task_name=task_name,
+        )
+
+    def rename_task(
+        dataset: "HafniaDataset",
+        old_task_name: str,
+        new_task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Rename a task in the dataset.
+        """
+        return dataset_transformations.rename_task(
+            dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
+        )
+
+    def select_samples_by_class_name(
+        dataset: HafniaDataset,
+        name: Union[List[str], str],
+        task_name: Optional[str] = None,
+        primitive: Optional[Type[Primitive]] = None,
+    ) -> HafniaDataset:
+        """
+        Select samples that contain at least one annotation with the specified class name(s).
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        return dataset_transformations.select_samples_by_class_name(
+            dataset=dataset, name=name, task_name=task_name, primitive=primitive
+        )
 
     def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
         """
         Merges two Hafnia datasets by concatenating their samples and updating the split names.
         """
-        ## Currently, only a very naive merging is implemented.
-        # In the future we need to verify that the class and tasks are compatible.
-        # Do they have similar classes and tasks? What to do if they don't?
-        # For now, we just concatenate the samples and keep the split names as they are.
-        merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
-        return dataset0.update_table(merged_samples)
 
-    # Dataset stats
-    split_counts = dataset_stats.split_counts
+        # Merges dataset info and checks for compatibility
+        merged_info = DatasetInfo.merge(dataset0.info, dataset1.info)
+
+        # Merges samples tables (removes incompatible columns)
+        merged_samples = table_transformations.merge_samples(samples0=dataset0.samples, samples1=dataset1.samples)
+
+        # Check if primitives have been removed during the merge_samples
+        for task in copy.deepcopy(merged_info.tasks):
+            if task.primitive.column_name() not in merged_samples.columns:
+                user_logger.warning(
+                    f"Task '{task.name}' with primitive '{task.primitive.__name__}' has been removed during the merge. "
+                    "This happens if the two datasets do not have the same primitives."
+                )
+                merged_info = merged_info.replace_task(old_task=task, new_task=None)
+
+        return HafniaDataset(info=merged_info, samples=merged_samples)
 
     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+        """
+        Splits the dataset into multiple datasets based on the 'split' column.
+        Returns a dictionary with split names as keys and HafniaDataset objects as values.
+        """
         if ColumnName.SPLIT not in self.samples.columns:
             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
 
@@ -360,10 +710,22 @@ class HafniaDataset:
         return splits
 
     def create_sample_dataset(self) -> "HafniaDataset":
-        if ColumnName.IS_SAMPLE not in self.samples.columns:
-            raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
-        table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
-        return self.update_table(table)
+        # Backwards compatibility. Remove in future versions when dataset have been updated
+        if "is_sample" in self.samples.columns:
+            user_logger.warning(
+                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
+                "Please use the 'tags' column with the tag 'sample' instead."
+            )
+            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
+            return self.update_samples(table)
+
+        if ColumnName.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+
+        table = self.samples.filter(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+        )
+        return self.update_samples(table)
 
     def create_split_dataset(self, split_name: Union[str | List[str]]) -> "HafniaDataset":
         if isinstance(split_name, str):
@@ -376,16 +738,12 @@ class HafniaDataset:
             raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")
 
         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
-        return self.update_table(filtered_dataset)
+        return self.update_samples(filtered_dataset)
 
-    def get_task_by_name(self, task_name: str) -> TaskInfo:
-        for task in self.info.tasks:
-            if task.name == task_name:
-                return task
-        raise ValueError(f"Task with name {task_name} not found in dataset info.")
-
-    def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
-        return HafniaDataset(info=self.info.model_copy(), samples=table)
+    def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
+        dataset = HafniaDataset(info=self.info.model_copy(deep=True), samples=table)
+        dataset.check_dataset_tasks()
+        return dataset
 
     @staticmethod
     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
@@ -411,7 +769,10 @@ class HafniaDataset:
 
         return True
 
-    def write(self, path_folder: Path, add_version: bool = False) -> None:
+    def copy(self) -> "HafniaDataset":
+        return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
+
+    def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
@@ -424,6 +785,10 @@ class HafniaDataset:
             )
             new_relative_paths.append(str(new_path.relative_to(path_folder)))
         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+
+        if drop_null_cols:  # Drops all unused/Null columns
+            table = table.drop(pl.selectors.by_dtype(pl.Null))
+
         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
@@ -448,51 +813,10 @@
             return False
         return True
 
-    def print_stats(self) -> None:
-        table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        table_base.add_column("Property", style="cyan")
-        table_base.add_column("Value")
-        table_base.add_row("Dataset Name", self.info.dataset_name)
-        table_base.add_row("Version", self.info.version)
-        table_base.add_row("Number of samples", str(len(self.samples)))
-        rprint(table_base)
-        rprint(self.info.tasks)
-
-        splits_sets = {
-            "All": SplitName.valid_splits(),
-            "Train": [SplitName.TRAIN],
-            "Validation": [SplitName.VAL],
-            "Test": [SplitName.TEST],
-        }
-        rows = []
-        for split_name, splits in splits_sets.items():
-            dataset_split = self.create_split_dataset(splits)
-            table = dataset_split.samples
-            row = {}
-            row["Split"] = split_name
-            row["Sample "] = str(len(table))
-            for PrimitiveType in PRIMITIVE_TYPES:
-                column_name = PrimitiveType.column_name()
-                objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-                if objects_df is None:
-                    continue
-                for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                    count = len(object_group[FieldName.CLASS_NAME])
-                    row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-            rows.append(row)
-
-        rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        for i_row, row in enumerate(rows):
-            if i_row == 0:
-                for column_name in row.keys():
-                    rich_table.add_column(column_name, justify="left", style="cyan")
-            rich_table.add_row(*[str(value) for value in row.values()])
-        rprint(rich_table)
-
 
 def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
-    check_hafnia_dataset(dataset)
+    dataset.check_dataset()
 
 
 def get_or_create_dataset_path_from_recipe(
@@ -522,89 +846,3 @@ def get_or_create_dataset_path_from_recipe(
     dataset.write(path_dataset)
 
     return path_dataset
-
-
-def check_hafnia_dataset(dataset: HafniaDataset):
-    user_logger.info("Checking Hafnia dataset...")
-    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
-    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
-
-    is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
-    if True not in is_sample_list:
-        raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
-
-    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-    expected_splits = SplitName.valid_splits()
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
-
-    expected_tasks = dataset.info.tasks
-    for task in expected_tasks:
-        primitive = task.primitive.__name__
-        column_name = task.primitive.column_name()
-        primitive_column = dataset.samples[column_name]
-        # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
-        msg_something_wrong = (
-            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
-            f"For '{primitive=}' and '{task.name=}' "
-        )
-        if primitive_column.dtype == pl.Null:
-            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
-
-        primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
-        if primitive_table.is_empty():
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
-            )
-
-        actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
-        if task.class_names is None:
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
-            )
-        defined_classes = set(task.class_names)
-
-        if not actual_classes.issubset(defined_classes):
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
-                f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
-            )
-        # Check class_indices
-        mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
-            lambda x: task.class_names.index(x), return_dtype=pl.Int64
-        )
-        table_indices = primitive_table[FieldName.CLASS_IDX]
-
-        error_msg = msg_something_wrong + (
-            f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
-        )
-        assert mapped_indices.equals(table_indices), error_msg
-
-    distribution = dataset.info.distributions or []
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
-    for PrimitiveType in PRIMITIVE_TYPES:
-        column_name = PrimitiveType.column_name()
-        if column_name not in dataset.samples.columns:
-            continue
-        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
-        if objects_df is None:
-            continue
-        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task:
-                continue
-            if task_name in distribution_names:
-                continue
-            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
-            raise ValueError(
-                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
-                f"classes: {class_names}. "
-            )
-
-    for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
-        sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation  # noqa: F841
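For orientation, below is a minimal usage sketch of the reworked HafniaDataset API, using only names that appear in the hunks above. The exact signatures and runtime behavior are assumptions based on this diff, not on the package documentation.

    # Sketch only: names taken from the 0.3.0 hunks above; behavior is assumed, not verified.
    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_name("mnist")   # download from the platform or load from the local cache
    dataset.print_stats()                        # now delegates to dataset_stats.print_stats

    # The boolean 'is_sample' column is replaced by a 'tags' list column; samples tagged
    # with TAG_IS_SAMPLE ("sample") form the small sample subset.
    sample_subset = dataset.create_sample_dataset()

    # update_table() is renamed to update_samples() and re-validates tasks on the new table.
    head_subset = dataset.update_samples(dataset.samples.head(100))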