hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +130 -16
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +83 -22
  11. hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
  12. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  13. hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
  14. hafnia/dataset/hafnia_dataset.py +396 -96
  15. hafnia/dataset/operations/dataset_stats.py +84 -73
  16. hafnia/dataset/operations/dataset_transformations.py +116 -47
  17. hafnia/dataset/operations/table_transformations.py +135 -17
  18. hafnia/dataset/primitives/bbox.py +25 -14
  19. hafnia/dataset/primitives/bitmask.py +22 -15
  20. hafnia/dataset/primitives/classification.py +16 -8
  21. hafnia/dataset/primitives/point.py +7 -3
  22. hafnia/dataset/primitives/polygon.py +15 -10
  23. hafnia/dataset/primitives/primitive.py +1 -1
  24. hafnia/dataset/primitives/segmentation.py +12 -9
  25. hafnia/experiment/hafnia_logger.py +0 -9
  26. hafnia/platform/dataset_recipe.py +7 -2
  27. hafnia/platform/datasets.py +5 -9
  28. hafnia/platform/download.py +24 -90
  29. hafnia/torch_helpers.py +12 -12
  30. hafnia/utils.py +17 -0
  31. hafnia/visualizations/image_visualizations.py +3 -1
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
  33. hafnia-0.4.1.dist-info/RECORD +57 -0
  34. hafnia-0.3.0.dist-info/RECORD +0 -53
  35. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
  36. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/operations/dataset_stats.py

@@ -1,14 +1,14 @@
  from __future__ import annotations

- from typing import TYPE_CHECKING, Dict, Optional, Type
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

  import polars as pl
  import rich
  from rich import print as rprint
+ from rich.progress import track
  from rich.table import Table
- from tqdm import tqdm

- from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+ from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
  from hafnia.dataset.operations.table_transformations import create_primitive_table
  from hafnia.dataset.primitives import PRIMITIVE_TYPES
  from hafnia.log import user_logger
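
Note for downstream code: the 0.3.0 constants ColumnName and FieldName are replaced by SampleField (table-level sample columns such as the split and file path) and PrimitiveField (fields inside the per-annotation structs such as task name and class name). A minimal migration sketch, assuming the new names are exported from hafnia.dataset.dataset_names exactly as they are imported in this diff; the full set of members is not shown here:

    # hafnia 0.3.0
    # from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
    # split_column = ColumnName.SPLIT
    # task_field = FieldName.TASK_NAME

    # hafnia 0.4.1
    from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName

    split_column = SampleField.SPLIT        # table-level column name
    task_field = PrimitiveField.TASK_NAME   # field inside a primitive (annotation) struct
    class_field = PrimitiveField.CLASS_NAME
    print(split_column, task_field, class_field, SplitName.TRAIN)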
@@ -18,14 +18,14 @@ if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type
      from hafnia.dataset.primitives.primitive import Primitive


- def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
+ def calculate_split_counts(dataset: HafniaDataset) -> Dict[str, int]:
      """
      Returns a dictionary with the counts of samples in each split of the dataset.
      """
-     return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+     return dict(dataset.samples[SampleField.SPLIT].value_counts().iter_rows())


- def class_counts_for_task(
+ def calculate_task_class_counts(
      dataset: HafniaDataset,
      primitive: Optional[Type[Primitive]] = None,
      task_name: Optional[str] = None,
@@ -53,7 +53,7 @@ def class_counts_for_task(
          dataset.samples[task.primitive.column_name()]
          .explode()
          .struct.unnest()
-         .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+         .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)[PrimitiveField.CLASS_NAME]
          .value_counts()
      )

@@ -65,7 +65,7 @@ def class_counts_for_task(
      return class_counts


- def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+ def calculate_class_counts(dataset: HafniaDataset) -> List[Dict[str, Any]]:
      """
      Get class counts for all tasks in the dataset.
      The counts are returned as a dictionary where keys are in the format
@@ -74,25 +74,59 @@ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
      Example:
          >>> counts = dataset.class_counts_all()
          >>> print(counts)
-         {
-             objects/bboxes/car: 500
-             objects/bboxes/person: 0
-             classifications/weather/sunny: 300
-             classifications/weather/rainy: 0
-             ...
-         }
+         [
+             {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'car', 'Count': 500},
+             {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'bus', 'Count': 100},
+             {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'indoor', 'Count': 300},
+             {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'outdoor', 'Count': 700},
+         ]
      """
-     class_counts = {}
+     count_info = []
      for task in dataset.info.tasks:
-         if task.class_names is None:
-             raise ValueError(f"Task '{task.name}' does not have class names defined.")
-         class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+         class_name_counts = dataset.calculate_task_class_counts(task_name=task.name)
+         for name, counts in class_name_counts.items():
+             count_info.append(
+                 {
+                     "Primitive": task.primitive.__name__,
+                     "Task Name": task.name,
+                     "Class Name": name,
+                     "Count": counts,
+                 }
+             )
+     return count_info

-         for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
-             count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
-             class_counts[count_name] = count

-     return class_counts
+ def calculate_primitive_counts(dataset: HafniaDataset) -> Dict[str, int]:
+     annotation_counts = {}
+     for task in dataset.info.tasks:
+         objects = dataset.create_primitive_table(task.primitive, task_name=task.name)
+         name = task.primitive.__name__
+         if task.name != task.primitive.default_task_name():
+             name = f"{name}.{task.name}"
+         annotation_counts[name] = len(objects)
+     return annotation_counts
+
+
+ def calculate_split_counts_extended(dataset: HafniaDataset) -> List[Dict[str, Any]]:
+     splits_sets = {
+         "All": SplitName.valid_splits(),
+         "Train": [SplitName.TRAIN],
+         "Validation": [SplitName.VAL],
+         "Test": [SplitName.TEST],
+     }
+     rows = []
+     for split_name, splits in splits_sets.items():
+         dataset_split = dataset.create_split_dataset(splits)
+         table = dataset_split.samples
+         row: Dict[str, Any] = {}
+         row["Split"] = split_name
+         row["Samples "] = str(len(table))
+
+         primitive_counts = calculate_primitive_counts(dataset_split)
+         row.update(primitive_counts)
+         rows.append(row)
+
+     return rows


  def print_stats(dataset: HafniaDataset) -> None:
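
Note for downstream code: the statistics helpers are renamed (split_counts becomes calculate_split_counts, class_counts_for_task becomes calculate_task_class_counts), and calculate_class_counts now returns a list of row dictionaries instead of a flat name-to-count mapping. A minimal usage sketch, assuming an existing HafniaDataset instance named dataset and that the helpers remain importable from hafnia.dataset.operations.dataset_stats (the module shown in the file list above):

    from hafnia.dataset.operations import dataset_stats

    split_counts = dataset_stats.calculate_split_counts(dataset)  # mapping of split name -> sample count
    class_rows = dataset_stats.calculate_class_counts(dataset)    # list of dicts: Primitive / Task Name / Class Name / Count
    for row in class_rows:
        print(row["Primitive"], row["Task Name"], row["Class Name"], row["Count"])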
@@ -118,10 +152,13 @@ def print_class_distribution(dataset: HafniaDataset) -> None:
      for task in dataset.info.tasks:
          if task.class_names is None:
              raise ValueError(f"Task '{task.name}' does not have class names defined.")
-         class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+         class_counts = dataset.calculate_task_class_counts(primitive=task.primitive, task_name=task.name)

          # Print class distribution
-         rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+         rich_table = Table(
+             title=f"Class Count for '{task.primitive.__name__}/{task.name}'",
+             show_lines=False,
+         )
          rich_table.add_column("Class Name", style="cyan")
          rich_table.add_column("Class Idx", style="cyan")
          rich_table.add_column("Count", justify="right")
@@ -136,32 +173,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
      Prints a table with sample counts and task counts for each primitive type
      in total and for each split (train, val, test).
      """
-     from hafnia.dataset.operations.table_transformations import create_primitive_table
-     from hafnia.dataset.primitives import PRIMITIVE_TYPES
-
-     splits_sets = {
-         "All": SplitName.valid_splits(),
-         "Train": [SplitName.TRAIN],
-         "Validation": [SplitName.VAL],
-         "Test": [SplitName.TEST],
-     }
-     rows = []
-     for split_name, splits in splits_sets.items():
-         dataset_split = dataset.create_split_dataset(splits)
-         table = dataset_split.samples
-         row = {}
-         row["Split"] = split_name
-         row["Sample "] = str(len(table))
-         for PrimitiveType in PRIMITIVE_TYPES:
-             column_name = PrimitiveType.column_name()
-             objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-             if objects_df is None:
-                 continue
-             for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                 count = len(object_group[FieldName.CLASS_NAME])
-                 row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-         rows.append(row)
-
+     rows = calculate_split_counts_extended(dataset)
      rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
      for i_row, row in enumerate(rows):
          if i_row == 0:
@@ -171,7 +183,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
      rprint(rich_table)


- def check_dataset(dataset: HafniaDataset):
+ def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
      """
      Performs various checks on the dataset to ensure its integrity and consistency.
      Raises errors if any issues are found.
@@ -179,24 +191,23 @@
      from hafnia.dataset.hafnia_dataset import Sample

      user_logger.info("Checking Hafnia dataset...")
-     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
      assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

-     sample_dataset = dataset.create_sample_dataset()
-     if len(sample_dataset) == 0:
-         raise ValueError("The dataset does not include a sample dataset")
+     if check_splits:
+         sample_dataset = dataset.create_sample_dataset()
+         if len(sample_dataset) == 0:
+             raise ValueError("The dataset does not include a sample dataset")
+
+         actual_splits = dataset.samples.select(pl.col(SampleField.SPLIT)).unique().to_series().to_list()
+         required_splits = SplitName.valid_splits()

-     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-     expected_splits = SplitName.valid_splits()
-     if set(actual_splits) != set(expected_splits):
-         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+         if not set(required_splits).issubset(set(actual_splits)):
+             raise ValueError(f"Expected all splits '{required_splits}' in dataset, but got '{actual_splits}'. ")

      dataset.check_dataset_tasks()

      expected_tasks = dataset.info.tasks
-     distribution = dataset.info.distributions or []
-     distribution_names = [task.name for task in distribution]
-     # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+     # Check that tasks found in the 'dataset.samples' matches the tasks defined in 'dataset.info.tasks'
      for PrimitiveType in PRIMITIVE_TYPES:
          column_name = PrimitiveType.column_name()
          if column_name not in dataset.samples.columns:
@@ -204,18 +215,18 @@
          objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
          if objects_df is None:
              continue
-         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+         for (task_name,), object_group in objects_df.group_by(PrimitiveField.TASK_NAME):
              has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-             if has_task or (task_name in distribution_names):
+             if has_task:
                  continue
-             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+             class_names = object_group[PrimitiveField.CLASS_NAME].unique().to_list()
              raise ValueError(
                  f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                 f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                 f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. Missing task has the following "
                  f"classes: {class_names}. "
              )

-     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+     for sample_dict in track(dataset, description="Checking samples in dataset"):
          sample = Sample(**sample_dict) # noqa: F841


@@ -238,7 +249,7 @@ def check_dataset_tasks(dataset: HafniaDataset):

          if len(dataset) > 0: # Check only performed for non-empty datasets
              primitive_table = (
-                 primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+                 primitive_column.explode().struct.unnest().filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
              )
              if primitive_table.is_empty():
                  raise ValueError(
@@ -246,7 +257,7 @@ def check_dataset_tasks(dataset: HafniaDataset):
                      + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
                  )

-             actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+             actual_classes = set(primitive_table[PrimitiveField.CLASS_NAME].unique().to_list())
              if task.class_names is None:
                  raise ValueError(
                      msg_something_wrong
@@ -261,12 +272,12 @@ def check_dataset_tasks(dataset: HafniaDataset):
                      f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
                  )
              # Check class_indices
-             mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+             mapped_indices = primitive_table[PrimitiveField.CLASS_NAME].map_elements(
                  lambda x: task.class_names.index(x), return_dtype=pl.Int64
              )
-             table_indices = primitive_table[FieldName.CLASS_IDX]
+             table_indices = primitive_table[PrimitiveField.CLASS_IDX]

              error_msg = msg_something_wrong + (
-                 f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+                 f"class indices in '{PrimitiveField.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
              )
              assert mapped_indices.equals(table_indices), error_msg
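
Note for downstream code: check_dataset gains a check_splits flag, and the split check now only requires the train/val/test splits to be present (a subset check) rather than to match exactly. A minimal sketch, assuming an existing HafniaDataset instance named dataset and that check_dataset remains importable from hafnia.dataset.operations.dataset_stats:

    from hafnia.dataset.operations.dataset_stats import check_dataset

    check_dataset(dataset)                      # full check, including sample subset and required splits
    check_dataset(dataset, check_splits=False)  # skip the sample-subset and split-presence checks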

hafnia/dataset/operations/dataset_transformations.py

@@ -31,25 +31,32 @@ that the signatures match.

  import json
  import re
+ import shutil
  import textwrap
  from pathlib import Path
- from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union

  import cv2
  import more_itertools
  import numpy as np
  import polars as pl
- from PIL import Image
- from tqdm import tqdm
+ from rich.progress import track

  from hafnia.dataset import dataset_helpers
- from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS, FieldName
+ from hafnia.dataset.dataset_names import (
+     OPS_REMOVE_CLASS,
+     PrimitiveField,
+     SampleField,
+     StorageFormat,
+ )
+ from hafnia.dataset.operations.table_transformations import update_class_indices
  from hafnia.dataset.primitives import get_primitive_type_from_string
  from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.log import user_logger
  from hafnia.utils import remove_duplicates_preserve_order

  if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type checking
-     from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
+     from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample, TaskInfo


  ### Image transformations ###
@@ -57,7 +64,7 @@ class AnonymizeByPixelation:
      def __init__(self, resize_factor: float = 0.10):
          self.resize_factor = resize_factor

-     def __call__(self, frame: np.ndarray) -> np.ndarray:
+     def __call__(self, frame: np.ndarray, sample: "Sample") -> np.ndarray:
          org_size = frame.shape[:2]
          frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
          frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
@@ -66,30 +73,100 @@ class AnonymizeByPixelation:

  def transform_images(
      dataset: "HafniaDataset",
-     transform: Callable[[np.ndarray], np.ndarray],
+     transform: Callable[[np.ndarray, "Sample"], np.ndarray],
      path_output: Path,
+     description: str = "Transform images",
  ) -> "HafniaDataset":
+     from hafnia.dataset.hafnia_dataset import Sample
+
      new_paths = []
      path_image_folder = path_output / "data"
      path_image_folder.mkdir(parents=True, exist_ok=True)

-     for org_path in tqdm(dataset.samples["file_name"].to_list(), desc="Transform images"):
-         org_path = Path(org_path)
-         if not org_path.exists():
-             raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
-
-         image = np.array(Image.open(org_path))
-         image_transformed = transform(image)
+     for sample_dict in track(dataset, description=description):
+         sample = Sample(**sample_dict)
+         image = sample.read_image()
+         image_transformed = transform(image, sample)
          new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)

          if not new_path.exists():
              raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
          new_paths.append(str(new_path))

-     table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
+     table = dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
      return dataset.update_samples(table)


+ def convert_to_image_storage_format(
+     dataset: "HafniaDataset",
+     path_output_folder: Path,
+     reextract_frames: bool,
+     image_format: str = "png",
+     transform: Optional[Callable[[np.ndarray, "Sample"], np.ndarray]] = None,
+ ) -> "HafniaDataset":
+     """
+     Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
+     to an image-based dataset by extracting frames.
+     """
+     from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+     path_images = path_output_folder / "data"
+     path_images.mkdir(parents=True, exist_ok=True)
+
+     # Only video format dataset samples are processed
+     video_based_samples = dataset.samples.filter(pl.col(SampleField.STORAGE_FORMAT) == StorageFormat.VIDEO)
+
+     if video_based_samples.is_empty():
+         user_logger.info("Dataset has no video-based samples. Returning dataset unchanged.")
+         return dataset
+
+     update_list = []
+     for (path_video,), video_samples in video_based_samples.group_by(SampleField.FILE_PATH):
+         assert Path(path_video).exists(), (
+             f"'{path_video}' not found. We expect the video to be downloaded to '{path_output_folder}'"
+         )
+         video = cv2.VideoCapture(str(path_video))
+
+         video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
+         for sample_dict in track(
+             video_samples.iter_rows(named=True),
+             total=video_samples.height,
+             description=f"Extracting frames from '{Path(path_video).name}'",
+         ):
+             frame_number = sample_dict[SampleField.COLLECTION_INDEX]
+             image_name = f"{Path(path_video).stem}_F{frame_number:06d}.{image_format}"
+             path_image = path_images / image_name
+
+             update_list.append(
+                 {
+                     SampleField.SAMPLE_INDEX: sample_dict[SampleField.SAMPLE_INDEX],
+                     SampleField.COLLECTION_ID: sample_dict[SampleField.COLLECTION_ID],
+                     SampleField.COLLECTION_INDEX: frame_number,
+                     SampleField.FILE_PATH: path_image.as_posix(),
+                     SampleField.STORAGE_FORMAT: StorageFormat.IMAGE,
+                 }
+             )
+             if reextract_frames:
+                 shutil.rmtree(path_image, ignore_errors=True)
+             if path_image.exists():
+                 continue
+
+             video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+             ret, frame_org = video.read()
+             if not ret:
+                 raise RuntimeError(f"Could not read frame {frame_number} from video '{path_video}'")
+
+             if transform is not None:
+                 frame_org = transform(frame_org, Sample(**sample_dict))
+
+             cv2.imwrite(str(path_image), frame_org)
+     df_updates = pl.DataFrame(update_list)
+     samples_as_images = dataset.samples.update(df_updates, on=[SampleField.COLLECTION_ID, SampleField.COLLECTION_INDEX])
+     hafnia_dataset = HafniaDataset(samples=samples_as_images, info=dataset.info)
+
+     return hafnia_dataset
+
+
  def get_task_info_from_task_name_and_primitive(
      tasks: List["TaskInfo"],
      task_name: Optional[str] = None,
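
Note for downstream code: image-transform callbacks now receive the Sample alongside the image (Callable[[np.ndarray, Sample], np.ndarray]), AnonymizeByPixelation.__call__ is updated to the same two-argument form, and the new convert_to_image_storage_format extracts frames from video-backed samples. A minimal sketch of the updated callback usage, assuming an existing HafniaDataset instance named dataset and that the names remain importable from hafnia.dataset.operations.dataset_transformations (the module shown in the file list above):

    from pathlib import Path

    from hafnia.dataset.operations.dataset_transformations import AnonymizeByPixelation, transform_images

    # AnonymizeByPixelation now matches the (frame, sample) callback signature expected by transform_images.
    pixelate = AnonymizeByPixelation(resize_factor=0.10)
    anonymized_dataset = transform_images(dataset, transform=pixelate, path_output=Path("anonymized_data"))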
@@ -156,13 +233,16 @@ def get_task_info_from_task_name_and_primitive(

  def class_mapper(
      dataset: "HafniaDataset",
-     class_mapping: Dict[str, str],
+     class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
      method: str = "strict",
      primitive: Optional[Type[Primitive]] = None,
      task_name: Optional[str] = None,
  ) -> "HafniaDataset":
      from hafnia.dataset.hafnia_dataset import HafniaDataset

+     if isinstance(class_mapping, list):
+         class_mapping = dict(class_mapping)
+
      allowed_methods = ("strict", "remove_undefined", "keep_undefined")
      if method not in allowed_methods:
          raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
@@ -170,7 +250,7 @@ def class_mapper(
      task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
      current_names = task.class_names or []

-     # Expand wildcard mappings
+     # Expand wildcard mappings e.g. {"Vehicle.*": "Vehicle"} to {"Vehicle.Car": "Vehicle", "Vehicle.Bus": "Vehicle"}
      class_mapping = expand_class_mapping(class_mapping, current_names)

      non_existing_mapping_names = set(class_mapping) - set(current_names)
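
Note for downstream code: class_mapper now also accepts the mapping as a list of (old_name, new_name) tuples, and the comment above documents the wildcard expansion applied to mapping keys. A minimal sketch, assuming an existing HafniaDataset instance named dataset whose task uses 'Vehicle.Car'/'Vehicle.Bus' style class names (the class names here are purely illustrative):

    from hafnia.dataset.operations.dataset_transformations import class_mapper

    # Dict form with a wildcard: every class matching "Vehicle.*" is collapsed into "Vehicle".
    remapped = class_mapper(dataset, class_mapping={"Vehicle.*": "Vehicle"}, method="keep_undefined")

    # Equivalent list-of-tuples form, accepted in 0.4.1.
    remapped = class_mapper(dataset, class_mapping=[("Vehicle.*", "Vehicle")], method="keep_undefined")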
@@ -213,31 +293,16 @@ def class_mapper(
      if OPS_REMOVE_CLASS in new_class_names:
          # Move __REMOVE__ to the end of the list if it exists
          new_class_names.append(new_class_names.pop(new_class_names.index(OPS_REMOVE_CLASS)))
-     name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}

      samples = dataset.samples
      samples_updated = samples.with_columns(
          pl.col(task.primitive.column_name())
          .list.eval(
              pl.element().struct.with_fields(
-                 pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                 .then(pl.field(FieldName.CLASS_NAME).replace_strict(class_mapping))
-                 .otherwise(pl.field(FieldName.CLASS_NAME))
-                 .alias(FieldName.CLASS_NAME)
-             )
-         )
-         .alias(task.primitive.column_name())
-     )
-
-     # Update class indices too
-     samples_updated = samples_updated.with_columns(
-         pl.col(task.primitive.column_name())
-         .list.eval(
-             pl.element().struct.with_fields(
-                 pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                 .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
-                 .otherwise(pl.field(FieldName.CLASS_IDX))
-                 .alias(FieldName.CLASS_IDX)
+                 pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                 .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(class_mapping, default="Missing"))
+                 .otherwise(pl.field(PrimitiveField.CLASS_NAME))
+                 .alias(PrimitiveField.CLASS_NAME)
              )
          )
          .alias(task.primitive.column_name())
@@ -246,7 +311,7 @@ def class_mapper(
      if OPS_REMOVE_CLASS in new_class_names: # Remove class_names that are mapped to REMOVE_CLASS
          samples_updated = samples_updated.with_columns(
              pl.col(task.primitive.column_name())
-             .list.filter(pl.element().struct.field(FieldName.CLASS_NAME) != OPS_REMOVE_CLASS)
+             .list.filter(pl.element().struct.field(PrimitiveField.CLASS_NAME) != OPS_REMOVE_CLASS)
              .alias(task.primitive.column_name())
          )

@@ -255,6 +320,10 @@ def class_mapper(
      new_task = task.model_copy(deep=True)
      new_task.class_names = new_class_names
      dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+
+     # Update class indices to match new class names
+     samples_updated = update_class_indices(samples_updated, new_task)
+
      return HafniaDataset(info=dataset_info, samples=samples_updated)


@@ -313,7 +382,7 @@ def rename_task(
          pl.col(old_task.primitive.column_name())
          .list.eval(
              pl.element().struct.with_fields(
-                 pl.field(FieldName.TASK_NAME).replace(old_task.name, new_task.name).alias(FieldName.TASK_NAME)
+                 pl.field(PrimitiveField.TASK_NAME).replace(old_task.name, new_task.name).alias(PrimitiveField.TASK_NAME)
              )
          )
          .alias(new_task.primitive.column_name())
@@ -339,8 +408,8 @@ def select_samples_by_class_name(
      samples = dataset.samples.filter(
          pl.col(task.primitive.column_name())
          .list.eval(
-             pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names)
-             & (pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+             pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names)
+             & (pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
          )
          .list.any()
      )
@@ -354,14 +423,14 @@ def _validate_inputs_select_samples_by_class_name(
      name: Union[List[str], str],
      task_name: Optional[str] = None,
      primitive: Optional[Type[Primitive]] = None,
- ) -> Tuple["TaskInfo", Set[str]]:
+ ) -> Tuple["TaskInfo", List[str]]:
      if isinstance(name, str):
          name = [name]
-     names = set(name)
+     names = list(name)

      # Check that specified names are available in at least one of the tasks
      available_names_across_tasks = set(more_itertools.flatten([t.class_names for t in dataset.info.tasks]))
-     missing_class_names_across_tasks = names - available_names_across_tasks
+     missing_class_names_across_tasks = set(names) - available_names_across_tasks
      if len(missing_class_names_across_tasks) > 0:
          raise ValueError(
              f"The specified names {list(names)} have not been found in any of the tasks. "
@@ -370,15 +439,15 @@ def _validate_inputs_select_samples_by_class_name(

      # Auto infer task if task_name and primitive are not provided
      if task_name is None and primitive is None:
-         tasks_with_names = [t for t in dataset.info.tasks if names.issubset(t.class_names or [])]
+         tasks_with_names = [t for t in dataset.info.tasks if set(names).issubset(t.class_names or [])]
          if len(tasks_with_names) == 0:
              raise ValueError(
-                 f"The specified names {list(names)} have not been found in any of the tasks. "
+                 f"The specified names {names} have not been found in any of the tasks. "
                  f"Available class names: {available_names_across_tasks}"
              )
          if len(tasks_with_names) > 1:
              raise ValueError(
-                 f"Found multiple tasks containing the specified names {list(names)}. "
+                 f"Found multiple tasks containing the specified names {names}. "
                  f"Specify either 'task_name' or 'primitive' to only select from one task. "
                  f"Tasks containing all provided names: {[t.name for t in tasks_with_names]}"
              )
@@ -393,7 +462,7 @@ def _validate_inputs_select_samples_by_class_name(
          )

      task_class_names = set(task.class_names or [])
-     missing_class_names = names - task_class_names
+     missing_class_names = set(names) - task_class_names
      if len(missing_class_names) > 0:
          raise ValueError(
              f"The specified names {list(missing_class_names)} have not been found for the '{task.name}' task. "