hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +13 -2
  2. cli/config.py +2 -1
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/profile_cmds.py +6 -5
  8. cli/runc_cmds.py +5 -5
  9. cli/trainer_package_cmds.py +65 -0
  10. hafnia/__init__.py +2 -0
  11. hafnia/data/factory.py +1 -2
  12. hafnia/dataset/dataset_helpers.py +0 -12
  13. hafnia/dataset/dataset_names.py +8 -4
  14. hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
  15. hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
  16. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  17. hafnia/dataset/dataset_upload_helper.py +206 -53
  18. hafnia/dataset/hafnia_dataset.py +432 -194
  19. hafnia/dataset/license_types.py +63 -0
  20. hafnia/dataset/operations/dataset_stats.py +260 -3
  21. hafnia/dataset/operations/dataset_transformations.py +325 -4
  22. hafnia/dataset/operations/table_transformations.py +39 -2
  23. hafnia/dataset/primitives/__init__.py +8 -0
  24. hafnia/dataset/primitives/classification.py +1 -1
  25. hafnia/experiment/hafnia_logger.py +112 -0
  26. hafnia/http.py +16 -2
  27. hafnia/platform/__init__.py +9 -3
  28. hafnia/platform/builder.py +12 -10
  29. hafnia/platform/dataset_recipe.py +99 -0
  30. hafnia/platform/datasets.py +44 -6
  31. hafnia/platform/download.py +2 -1
  32. hafnia/platform/experiment.py +51 -56
  33. hafnia/platform/trainer_package.py +57 -0
  34. hafnia/utils.py +64 -13
  35. hafnia/visualizations/image_visualizations.py +3 -3
  36. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
  37. hafnia-0.3.0.dist-info/RECORD +53 -0
  38. cli/recipe_cmds.py +0 -45
  39. hafnia-0.2.4.dist-info/RECORD +0 -49
  40. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
  41. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/license_types.py (new file)
@@ -0,0 +1,63 @@
+ from typing import List, Optional
+
+ from hafnia.dataset.hafnia_dataset import License
+
+ LICENSE_TYPES: List[License] = [
+     License(
+         name="Creative Commons: Attribution-NonCommercial-ShareAlike 2.0 Generic",
+         name_short="CC BY-NC-SA 2.0",
+         url="https://creativecommons.org/licenses/by-nc-sa/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NonCommercial 2.0 Generic",
+         name_short="CC BY-NC 2.0",
+         url="https://creativecommons.org/licenses/by-nc/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NonCommercial-NoDerivs 2.0 Generic",
+         name_short="CC BY-NC-ND 2.0",
+         url="https://creativecommons.org/licenses/by-nc-nd/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution 2.0 Generic",
+         name_short="CC BY 2.0",
+         url="https://creativecommons.org/licenses/by/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-ShareAlike 2.0 Generic",
+         name_short="CC BY-SA 2.0",
+         url="https://creativecommons.org/licenses/by-sa/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NoDerivs 2.0 Generic",
+         name_short="CC BY-ND 2.0",
+         url="https://creativecommons.org/licenses/by-nd/2.0/",
+     ),
+     License(
+         name="Flickr: No known copyright restrictions",
+         name_short="Flickr",
+         url="https://flickr.com/commons/usage/",
+     ),
+     License(
+         name="United States Government Work",
+         name_short="US Gov",
+         url="http://www.usa.gov/copyright.shtml",
+     ),
+ ]
+
+
+ def get_license_by_url(url: str) -> Optional[License]:
+     for license in LICENSE_TYPES:
+         # To handle http urls
+         license_url = (license.url or "").replace("http://", "https://")
+         url_https = url.replace("http://", "https://")
+         if license_url == url_https:
+             return license
+     raise ValueError(f"License with URL '{url}' not found.")
+
+
+ def get_license_by_short_name(short_name: str) -> Optional[License]:
+     for license in LICENSE_TYPES:
+         if license.name_short == short_name:
+             return license
+     raise ValueError(f"License with short name '{short_name}' not found.")
hafnia/dataset/operations/dataset_stats.py
@@ -1,11 +1,21 @@
  from __future__ import annotations

- from typing import TYPE_CHECKING, Dict
+ from typing import TYPE_CHECKING, Dict, Optional, Type

- from hafnia.dataset.dataset_names import ColumnName
+ import polars as pl
+ import rich
+ from rich import print as rprint
+ from rich.table import Table
+ from tqdm import tqdm

- if TYPE_CHECKING:
+ from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+ from hafnia.dataset.operations.table_transformations import create_primitive_table
+ from hafnia.dataset.primitives import PRIMITIVE_TYPES
+ from hafnia.log import user_logger
+
+ if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
      from hafnia.dataset.hafnia_dataset import HafniaDataset
+     from hafnia.dataset.primitives.primitive import Primitive


  def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
@@ -13,3 +23,250 @@ def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
      Returns a dictionary with the counts of samples in each split of the dataset.
      """
      return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+
+
+ def class_counts_for_task(
+     dataset: HafniaDataset,
+     primitive: Optional[Type[Primitive]] = None,
+     task_name: Optional[str] = None,
+ ) -> Dict[str, int]:
+     """
+     Determines class name counts for a specific task in the dataset.
+
+     The counts are returned as a dictionary where keys are class names and values are their respective counts.
+     Note that class names with zero counts are included in the dictionary and that
+     the order of the dictionary matches the class idx.
+
+     >>> counts = dataset.class_counts_for_task(primitive=Bbox)
+     >>> print(counts)
+     {
+         'person': 0,  # Note: Zero count classes are included to maintain order with class idx
+         'car': 500,
+         'bicycle': 0,
+         'bus': 100,
+         'truck': 150,
+         'motorcycle': 50,
+     }
+     """
+     task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+     class_counts_df = (
+         dataset.samples[task.primitive.column_name()]
+         .explode()
+         .struct.unnest()
+         .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+         .value_counts()
+     )
+
+     # Initialize counts with zero for all classes to ensure zero-count classes are included
+     # and to have class names in the order of class idx
+     class_counts = {name: 0 for name in task.class_names or []}
+     class_counts.update(dict(class_counts_df.iter_rows()))
+
+     return class_counts
+
+
+ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+     """
+     Get class counts for all tasks in the dataset.
+     The counts are returned as a dictionary where keys are in the format
+     '{primitive_name}/{task_name}/{class_name}' and values are their respective counts.
+
+     Example:
+         >>> counts = dataset.class_counts_all()
+         >>> print(counts)
+         {
+             objects/bboxes/car: 500
+             objects/bboxes/person: 0
+             classifications/weather/sunny: 300
+             classifications/weather/rainy: 0
+             ...
+         }
+     """
+     class_counts = {}
+     for task in dataset.info.tasks:
+         if task.class_names is None:
+             raise ValueError(f"Task '{task.name}' does not have class names defined.")
+         class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+         for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
+             count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
+             class_counts[count_name] = count
+
+     return class_counts
+
+
+ def print_stats(dataset: HafniaDataset) -> None:
+     """
+     Prints verbose statistics about the dataset, including dataset name, version,
+     number of samples, and detailed counts of samples and tasks.
+     """
+     table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+     table_base.add_column("Property", style="cyan")
+     table_base.add_column("Value")
+     table_base.add_row("Dataset Name", dataset.info.dataset_name)
+     table_base.add_row("Version", dataset.info.version)
+     table_base.add_row("Number of samples", str(len(dataset.samples)))
+     rprint(table_base)
+     print_sample_and_task_counts(dataset)
+     print_class_distribution(dataset)
+
+
+ def print_class_distribution(dataset: HafniaDataset) -> None:
+     """
+     Prints the class distribution for each task in the dataset.
+     """
+     for task in dataset.info.tasks:
+         if task.class_names is None:
+             raise ValueError(f"Task '{task.name}' does not have class names defined.")
+         class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+         # Print class distribution
+         rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+         rich_table.add_column("Class Name", style="cyan")
+         rich_table.add_column("Class Idx", style="cyan")
+         rich_table.add_column("Count", justify="right")
+         for class_name, count in class_counts.items():
+             class_idx = task.class_names.index(class_name)  # Get class idx from task info
+             rich_table.add_row(class_name, str(class_idx), str(count))
+         rprint(rich_table)
+
+
+ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
+     """
+     Prints a table with sample counts and task counts for each primitive type
+     in total and for each split (train, val, test).
+     """
+     from hafnia.dataset.operations.table_transformations import create_primitive_table
+     from hafnia.dataset.primitives import PRIMITIVE_TYPES
+
+     splits_sets = {
+         "All": SplitName.valid_splits(),
+         "Train": [SplitName.TRAIN],
+         "Validation": [SplitName.VAL],
+         "Test": [SplitName.TEST],
+     }
+     rows = []
+     for split_name, splits in splits_sets.items():
+         dataset_split = dataset.create_split_dataset(splits)
+         table = dataset_split.samples
+         row = {}
+         row["Split"] = split_name
+         row["Sample "] = str(len(table))
+         for PrimitiveType in PRIMITIVE_TYPES:
+             column_name = PrimitiveType.column_name()
+             objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+             if objects_df is None:
+                 continue
+             for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                 count = len(object_group[FieldName.CLASS_NAME])
+                 row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+         rows.append(row)
+
+     rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+     for i_row, row in enumerate(rows):
+         if i_row == 0:
+             for column_name in row.keys():
+                 rich_table.add_column(column_name, justify="left", style="cyan")
+         rich_table.add_row(*[str(value) for value in row.values()])
+     rprint(rich_table)
+
+
+ def check_dataset(dataset: HafniaDataset):
+     """
+     Performs various checks on the dataset to ensure its integrity and consistency.
+     Raises errors if any issues are found.
+     """
+     from hafnia.dataset.hafnia_dataset import Sample
+
+     user_logger.info("Checking Hafnia dataset...")
+     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
+     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+     sample_dataset = dataset.create_sample_dataset()
+     if len(sample_dataset) == 0:
+         raise ValueError("The dataset does not include a sample dataset")
+
+     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+     expected_splits = SplitName.valid_splits()
+     if set(actual_splits) != set(expected_splits):
+         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+     dataset.check_dataset_tasks()
+
+     expected_tasks = dataset.info.tasks
+     distribution = dataset.info.distributions or []
+     distribution_names = [task.name for task in distribution]
+     # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+     for PrimitiveType in PRIMITIVE_TYPES:
+         column_name = PrimitiveType.column_name()
+         if column_name not in dataset.samples.columns:
+             continue
+         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+         if objects_df is None:
+             continue
+         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+             if has_task or (task_name in distribution_names):
+                 continue
+             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+             raise ValueError(
+                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                 f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                 f"classes: {class_names}. "
+             )
+
+     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+         sample = Sample(**sample_dict)  # noqa: F841
+
+
+ def check_dataset_tasks(dataset: HafniaDataset):
+     """
+     Checks that the tasks defined in 'dataset.info.tasks' are consistent with the data in 'dataset.samples'.
+     """
+     dataset.info.check_for_duplicate_task_names()
+
+     for task in dataset.info.tasks:
+         primitive = task.primitive.__name__
+         column_name = task.primitive.column_name()
+         primitive_column = dataset.samples[column_name]
+         msg_something_wrong = (
+             f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+             f"For '{primitive=}' and '{task.name=}' "
+         )
+         if primitive_column.dtype == pl.Null:
+             raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+         if len(dataset) > 0:  # Check only performed for non-empty datasets
+             primitive_table = (
+                 primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+             )
+             if primitive_table.is_empty():
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+                 )
+
+             actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+             if task.class_names is None:
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+                 )
+             defined_classes = set(task.class_names)
+
+             if not actual_classes.issubset(defined_classes):
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
+                     f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+                 )
+             # Check class_indices
+             mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+                 lambda x: task.class_names.index(x), return_dtype=pl.Int64
+             )
+             table_indices = primitive_table[FieldName.CLASS_IDX]
+
+             error_msg = msg_something_wrong + (
+                 f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+             )
+             assert mapped_indices.equals(table_indices), error_msg
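A usage sketch (not part of the diff) of the new statistics helpers in dataset_stats.py. Here `dataset` is a hypothetical, already-loaded HafniaDataset and the "bboxes" task name is only an assumption; the docstrings above also suggest the same helpers are exposed as HafniaDataset methods.

    from hafnia.dataset.operations import dataset_stats

    dataset_stats.check_dataset(dataset)   # raises if splits, tasks, or class indices are inconsistent
    dataset_stats.print_stats(dataset)     # rich tables: base info, per-split counts, class distributions
    counts = dataset_stats.class_counts_for_task(dataset, task_name="bboxes")
    empty_classes = [name for name, count in counts.items() if count == 0]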
hafnia/dataset/operations/dataset_transformations.py
@@ -29,19 +29,27 @@ HafniaDataset class and a RecipeTransform class in the `data_recipe/recipe_trans
  that the signatures match.
  """

+ import json
+ import re
+ import textwrap
  from pathlib import Path
- from typing import TYPE_CHECKING, Callable
+ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Type, Union

  import cv2
+ import more_itertools
  import numpy as np
  import polars as pl
  from PIL import Image
  from tqdm import tqdm

  from hafnia.dataset import dataset_helpers
+ from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS, FieldName
+ from hafnia.dataset.primitives import get_primitive_type_from_string
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.utils import remove_duplicates_preserve_order

- if TYPE_CHECKING:
-     from hafnia.dataset.hafnia_dataset import HafniaDataset
+ if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
+     from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo


  ### Image transformations ###
@@ -79,4 +87,317 @@ def transform_images(
          new_paths.append(str(new_path))

      table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
-     return dataset.update_table(table)
+     return dataset.update_samples(table)
+
+
+ def get_task_info_from_task_name_and_primitive(
+     tasks: List["TaskInfo"],
+     task_name: Optional[str] = None,
+     primitive: Union[None, str, Type[Primitive]] = None,
+ ) -> "TaskInfo":
+     if len(tasks) == 0:
+         raise ValueError("Dataset has no tasks defined.")
+
+     tasks_str = "\n".join([f"\t{task.__repr__()}" for task in tasks])
+     if task_name is None and primitive is None:
+         if len(tasks) == 1:
+             return tasks[0]
+         else:
+             raise ValueError(
+                 "For multiple tasks, you will need to specify 'task_name' or 'type_primitive' "
+                 "to return a unique task. The dataset contains the following tasks: \n" + tasks_str
+             )
+
+     if isinstance(primitive, str):
+         primitive = get_primitive_type_from_string(primitive)
+
+     tasks_filtered = tasks
+     if primitive is None:
+         tasks_filtered = [task for task in tasks if task.name == task_name]
+
+         if len(tasks_filtered) == 0:
+             raise ValueError(f"No task found with {task_name=}. Available tasks: \n {tasks_str}")
+
+         unique_primitives = set(task.primitive for task in tasks_filtered)
+         if len(unique_primitives) > 1:
+             raise ValueError(
+                 f"Found multiple tasks with {task_name=} using different primitives {unique_primitives=}. "
+                 "Please specify the primitive type to make it unique. "
+                 f"The dataset contains the following tasks: \n {tasks_str}"
+             )
+         primitive = list(unique_primitives)[0]
+
+     if task_name is None:
+         tasks_filtered = [task for task in tasks if task.primitive == primitive]
+         if len(tasks_filtered) == 0:
+             raise ValueError(f"No task found with {primitive=}. Available tasks: \n {tasks_str}")
+
+         unique_task_names = set(task.name for task in tasks_filtered)
+         if len(unique_task_names) > 1:
+             raise ValueError(
+                 f"Found multiple tasks with {primitive=} using different task names {unique_task_names=}. "
+                 "Please specify the 'task_name' to make it unique."
+                 f"The dataset contains the following tasks: \n {tasks_str}"
+             )
+         task_name = list(unique_task_names)[0]
+
+     tasks_filtered = [task for task in tasks_filtered if task.primitive == primitive and task.name == task_name]
+     if len(tasks_filtered) == 0:
+         raise ValueError(f"No task found with {task_name=} and {primitive=}. Available tasks: \n {tasks_str}")
+
+     if len(tasks_filtered) > 1:
+         raise ValueError(
+             f"Multiple tasks found with {task_name=} and {primitive=}. "
+             f"This should never happen. The dataset contains the following tasks: \n {tasks_str}"
+         )
+     task = tasks_filtered[0]
+     return task
+
+
+ def class_mapper(
+     dataset: "HafniaDataset",
+     class_mapping: Dict[str, str],
+     method: str = "strict",
+     primitive: Optional[Type[Primitive]] = None,
+     task_name: Optional[str] = None,
+ ) -> "HafniaDataset":
+     from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+     allowed_methods = ("strict", "remove_undefined", "keep_undefined")
+     if method not in allowed_methods:
+         raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
+
+     task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+     current_names = task.class_names or []
+
+     # Expand wildcard mappings
+     class_mapping = expand_class_mapping(class_mapping, current_names)
+
+     non_existing_mapping_names = set(class_mapping) - set(current_names)
+     if len(non_existing_mapping_names) > 0:
+         raise ValueError(
+             f"The specified class mapping contains class names {list(non_existing_mapping_names)} "
+             f"that do not exist in the dataset task '{task.name}'. "
+             f"Available class names: {current_names}"
+         )
+
+     missing_class_names = [c for c in current_names if c not in class_mapping]  # List-comprehension to preserve order
+     class_mapping = class_mapping.copy()
+     if method == "strict":
+         pass  # Continue to strict mapping below
+     elif method == "remove_undefined":
+         for missing_class_name in missing_class_names:
+             class_mapping[missing_class_name] = OPS_REMOVE_CLASS
+     elif method == "keep_undefined":
+         for missing_class_name in missing_class_names:
+             class_mapping[missing_class_name] = missing_class_name
+     else:
+         raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
+
+     missing_class_names = [c for c in current_names if c not in class_mapping]
+     if len(missing_class_names) > 0:
+         error_msg = f"""\
+         The specified class mapping is not a strict mapping - meaning that all class names have not
+         been mapped to a new class name.
+         In the current mapping, the following classes {list(missing_class_names)} have not been mapped.
+         The currently specified mapping is:
+         {json.dumps(class_mapping, indent=2)}
+         A strict mapping will replace all old class names (dictionary keys) to new class names (dictionary values).
+         Please update the mapping to include all class names from the dataset task '{task.name}'.
+         To keep class map to the same name e.g. 'person' = 'person'
+         or remove class by using the '__REMOVE__' key, e.g. 'person': '__REMOVE__'."""
+         raise ValueError(textwrap.dedent(error_msg))
+
+     new_class_names = remove_duplicates_preserve_order(class_mapping.values())
+
+     if OPS_REMOVE_CLASS in new_class_names:
+         # Move __REMOVE__ to the end of the list if it exists
+         new_class_names.append(new_class_names.pop(new_class_names.index(OPS_REMOVE_CLASS)))
+     name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}
+
+     samples = dataset.samples
+     samples_updated = samples.with_columns(
+         pl.col(task.primitive.column_name())
+         .list.eval(
+             pl.element().struct.with_fields(
+                 pl.when(pl.field(FieldName.TASK_NAME) == task.name)
+                 .then(pl.field(FieldName.CLASS_NAME).replace_strict(class_mapping))
+                 .otherwise(pl.field(FieldName.CLASS_NAME))
+                 .alias(FieldName.CLASS_NAME)
+             )
+         )
+         .alias(task.primitive.column_name())
+     )
+
+     # Update class indices too
+     samples_updated = samples_updated.with_columns(
+         pl.col(task.primitive.column_name())
+         .list.eval(
+             pl.element().struct.with_fields(
+                 pl.when(pl.field(FieldName.TASK_NAME) == task.name)
+                 .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
+                 .otherwise(pl.field(FieldName.CLASS_IDX))
+                 .alias(FieldName.CLASS_IDX)
+             )
+         )
+         .alias(task.primitive.column_name())
+     )
+
+     if OPS_REMOVE_CLASS in new_class_names:  # Remove class_names that are mapped to REMOVE_CLASS
+         samples_updated = samples_updated.with_columns(
+             pl.col(task.primitive.column_name())
+             .list.filter(pl.element().struct.field(FieldName.CLASS_NAME) != OPS_REMOVE_CLASS)
+             .alias(task.primitive.column_name())
+         )
+
+         new_class_names = [c for c in new_class_names if c != OPS_REMOVE_CLASS]
+
+     new_task = task.model_copy(deep=True)
+     new_task.class_names = new_class_names
+     dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+     return HafniaDataset(info=dataset_info, samples=samples_updated)
+
+
+ def expand_class_mapping(wildcard_mapping: Dict[str, str], class_names: List[str]) -> Dict[str, str]:
+     """
+     Expand a wildcard class mapping to a full explicit mapping.
+
+     This function takes a mapping that may contain wildcard patterns (using '*')
+     and expands them to match actual class names from a dataset. Exact matches
+     take precedence over wildcard patterns.
+
+     Examples:
+         >>> from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS
+         >>> wildcard_mapping = {
+         ...     "Person": "Person",
+         ...     "Vehicle.*": "Vehicle",
+         ...     "Vehicle.Trailer": OPS_REMOVE_CLASS
+         ... }
+         >>> class_names = [
+         ...     "Person", "Vehicle.Car", "Vehicle.Trailer", "Vehicle.Bus", "Animal.Dog"
+         ... ]
+         >>> result = expand_wildcard_mapping(wildcard_mapping, class_names)
+         >>> print(result)
+         {
+             "Person": "Person",
+             "Vehicle.Car": "Vehicle",
+             "Vehicle.Trailer": OPS_REMOVE_CLASS,  # Exact match overrides wildcard
+             "Vehicle.Bus": "Vehicle",
+             # Note: "Animal.Dog" is not included as it doesn't match any pattern
+         }
+     """
+     expanded_mapping = {}
+     for match_pattern, mapping_value in wildcard_mapping.items():
+         if "*" in match_pattern:
+             # Convert wildcard pattern to regex: Escape special regex characters except *, then replace * with .*
+             regex_pattern = re.escape(match_pattern).replace("\\*", ".*")
+             class_names_matched = [cn for cn in class_names if re.fullmatch(regex_pattern, cn)]
+             expanded_mapping.update({cn: mapping_value for cn in class_names_matched})
+         else:
+             expanded_mapping.pop(match_pattern, None)
+             expanded_mapping[match_pattern] = mapping_value
+     return expanded_mapping
+
+
+ def rename_task(
+     dataset: "HafniaDataset",
+     old_task_name: str,
+     new_task_name: str,
+ ) -> "HafniaDataset":
+     from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+     old_task = dataset.info.get_task_by_name(task_name=old_task_name)
+     new_task = old_task.model_copy(deep=True)
+     new_task.name = new_task_name
+     samples = dataset.samples.with_columns(
+         pl.col(old_task.primitive.column_name())
+         .list.eval(
+             pl.element().struct.with_fields(
+                 pl.field(FieldName.TASK_NAME).replace(old_task.name, new_task.name).alias(FieldName.TASK_NAME)
+             )
+         )
+         .alias(new_task.primitive.column_name())
+     )
+
+     dataset_info = dataset.info.replace_task(old_task=old_task, new_task=new_task)
+     return HafniaDataset(info=dataset_info, samples=samples)
+
+
+ def select_samples_by_class_name(
+     dataset: "HafniaDataset",
+     name: Union[List[str], str],
+     task_name: Optional[str] = None,
+     primitive: Optional[Type[Primitive]] = None,
+ ) -> "HafniaDataset":
+     task, class_names = _validate_inputs_select_samples_by_class_name(
+         dataset=dataset,
+         name=name,
+         task_name=task_name,
+         primitive=primitive,
+     )
+
+     samples = dataset.samples.filter(
+         pl.col(task.primitive.column_name())
+         .list.eval(
+             pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names)
+             & (pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+         )
+         .list.any()
+     )
+
+     dataset_updated = dataset.update_samples(samples)
+     return dataset_updated
+
+
+ def _validate_inputs_select_samples_by_class_name(
+     dataset: "HafniaDataset",
+     name: Union[List[str], str],
+     task_name: Optional[str] = None,
+     primitive: Optional[Type[Primitive]] = None,
+ ) -> Tuple["TaskInfo", Set[str]]:
+     if isinstance(name, str):
+         name = [name]
+     names = set(name)
+
+     # Check that specified names are available in at least one of the tasks
+     available_names_across_tasks = set(more_itertools.flatten([t.class_names for t in dataset.info.tasks]))
+     missing_class_names_across_tasks = names - available_names_across_tasks
+     if len(missing_class_names_across_tasks) > 0:
+         raise ValueError(
+             f"The specified names {list(names)} have not been found in any of the tasks. "
+             f"Available class names: {available_names_across_tasks}"
+         )
+
+     # Auto infer task if task_name and primitive are not provided
+     if task_name is None and primitive is None:
+         tasks_with_names = [t for t in dataset.info.tasks if names.issubset(t.class_names or [])]
+         if len(tasks_with_names) == 0:
+             raise ValueError(
+                 f"The specified names {list(names)} have not been found in any of the tasks. "
+                 f"Available class names: {available_names_across_tasks}"
+             )
+         if len(tasks_with_names) > 1:
+             raise ValueError(
+                 f"Found multiple tasks containing the specified names {list(names)}. "
+                 f"Specify either 'task_name' or 'primitive' to only select from one task. "
+                 f"Tasks containing all provided names: {[t.name for t in tasks_with_names]}"
+             )
+
+         task = tasks_with_names[0]
+
+     else:
+         task = get_task_info_from_task_name_and_primitive(
+             tasks=dataset.info.tasks,
+             task_name=task_name,
+             primitive=primitive,
+         )
+
+     task_class_names = set(task.class_names or [])
+     missing_class_names = names - task_class_names
+     if len(missing_class_names) > 0:
+         raise ValueError(
+             f"The specified names {list(missing_class_names)} have not been found for the '{task.name}' task. "
+             f"Available class names: {task_class_names}"
+         )
+
+     return task, names
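A usage sketch (not part of the diff) of the new class-mapping helpers in dataset_transformations.py. The `dataset` handle and the class names ("Person", "Vehicle.Car", "Vehicle.Trailer", "Vehicle.Bus") are hypothetical; OPS_REMOVE_CLASS, class_mapper with wildcard expansion, and select_samples_by_class_name come from the diff above.

    from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS
    from hafnia.dataset.operations.dataset_transformations import class_mapper, select_samples_by_class_name

    # Collapse all "Vehicle.*" classes into a single "Vehicle" class, drop trailers explicitly,
    # and drop every class not covered by the mapping (method="remove_undefined").
    remapped = class_mapper(
        dataset,  # hypothetical, already-loaded HafniaDataset
        class_mapping={"Person": "Person", "Vehicle.*": "Vehicle", "Vehicle.Trailer": OPS_REMOVE_CLASS},
        method="remove_undefined",
    )
    person_samples = select_samples_by_class_name(remapped, name="Person")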