hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +13 -2
- cli/config.py +2 -1
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/profile_cmds.py +6 -5
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +0 -12
- hafnia/dataset/dataset_names.py +8 -4
- hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
- hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +206 -53
- hafnia/dataset/hafnia_dataset.py +432 -194
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +260 -3
- hafnia/dataset/operations/dataset_transformations.py +325 -4
- hafnia/dataset/operations/table_transformations.py +39 -2
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/classification.py +1 -1
- hafnia/experiment/hafnia_logger.py +112 -0
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +99 -0
- hafnia/platform/datasets.py +44 -6
- hafnia/platform/download.py +2 -1
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +64 -13
- hafnia/visualizations/image_visualizations.py +3 -3
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
- hafnia-0.3.0.dist-info/RECORD +53 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/license_types.py (new file)

@@ -0,0 +1,63 @@
+from typing import List, Optional
+
+from hafnia.dataset.hafnia_dataset import License
+
+LICENSE_TYPES: List[License] = [
+    License(
+        name="Creative Commons: Attribution-NonCommercial-ShareAlike 2.0 Generic",
+        name_short="CC BY-NC-SA 2.0",
+        url="https://creativecommons.org/licenses/by-nc-sa/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NonCommercial 2.0 Generic",
+        name_short="CC BY-NC 2.0",
+        url="https://creativecommons.org/licenses/by-nc/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NonCommercial-NoDerivs 2.0 Generic",
+        name_short="CC BY-NC-ND 2.0",
+        url="https://creativecommons.org/licenses/by-nc-nd/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution 2.0 Generic",
+        name_short="CC BY 2.0",
+        url="https://creativecommons.org/licenses/by/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-ShareAlike 2.0 Generic",
+        name_short="CC BY-SA 2.0",
+        url="https://creativecommons.org/licenses/by-sa/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NoDerivs 2.0 Generic",
+        name_short="CC BY-ND 2.0",
+        url="https://creativecommons.org/licenses/by-nd/2.0/",
+    ),
+    License(
+        name="Flickr: No known copyright restrictions",
+        name_short="Flickr",
+        url="https://flickr.com/commons/usage/",
+    ),
+    License(
+        name="United States Government Work",
+        name_short="US Gov",
+        url="http://www.usa.gov/copyright.shtml",
+    ),
+]
+
+
+def get_license_by_url(url: str) -> Optional[License]:
+    for license in LICENSE_TYPES:
+        # To handle http urls
+        license_url = (license.url or "").replace("http://", "https://")
+        url_https = url.replace("http://", "https://")
+        if license_url == url_https:
+            return license
+    raise ValueError(f"License with URL '{url}' not found.")
+
+
+def get_license_by_short_name(short_name: str) -> Optional[License]:
+    for license in LICENSE_TYPES:
+        if license.name_short == short_name:
+            return license
+    raise ValueError(f"License with short name '{short_name}' not found.")
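Usage note for the two lookup helpers above: both are annotated as returning `Optional[License]`, but an unmatched URL or short name raises `ValueError` rather than returning `None`, so callers that want a soft lookup must catch the exception. A minimal, illustrative sketch (assumes hafnia 0.3.0 is installed; the values come from the `LICENSE_TYPES` table above):

```python
# Illustrative sketch only: assumes hafnia 0.3.0 with the new hafnia.dataset.license_types module.
from hafnia.dataset.license_types import get_license_by_short_name, get_license_by_url

license = get_license_by_short_name("CC BY 2.0")
print(license.name)  # "Creative Commons: Attribution 2.0 Generic"
print(license.url)   # "https://creativecommons.org/licenses/by/2.0/"

# URL lookup normalizes http:// to https:// before comparing, so both schemes resolve.
assert get_license_by_url("http://creativecommons.org/licenses/by/2.0/") is license

# Despite the Optional[License] annotation, a miss raises instead of returning None.
try:
    get_license_by_short_name("GPL-3.0")
except ValueError as err:
    print(err)  # License with short name 'GPL-3.0' not found.
```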
hafnia/dataset/operations/dataset_stats.py

@@ -1,11 +1,21 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
-
+import polars as pl
+import rich
+from rich import print as rprint
+from rich.table import Table
+from tqdm import tqdm
 
-
+from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+from hafnia.dataset.operations.table_transformations import create_primitive_table
+from hafnia.dataset.primitives import PRIMITIVE_TYPES
+from hafnia.log import user_logger
+
+if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
     from hafnia.dataset.hafnia_dataset import HafniaDataset
+    from hafnia.dataset.primitives.primitive import Primitive
 
 
 def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
@@ -13,3 +23,250 @@ def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
     Returns a dictionary with the counts of samples in each split of the dataset.
     """
     return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+
+
+def class_counts_for_task(
+    dataset: HafniaDataset,
+    primitive: Optional[Type[Primitive]] = None,
+    task_name: Optional[str] = None,
+) -> Dict[str, int]:
+    """
+    Determines class name counts for a specific task in the dataset.
+
+    The counts are returned as a dictionary where keys are class names and values are their respective counts.
+    Note that class names with zero counts are included in the dictionary and that
+    the order of the dictionary matches the class idx.
+
+    >>> counts = dataset.class_counts_for_task(primitive=Bbox)
+    >>> print(counts)
+    {
+        'person': 0,  # Note: Zero count classes are included to maintain order with class idx
+        'car': 500,
+        'bicycle': 0,
+        'bus': 100,
+        'truck': 150,
+        'motorcycle': 50,
+    }
+    """
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+    class_counts_df = (
+        dataset.samples[task.primitive.column_name()]
+        .explode()
+        .struct.unnest()
+        .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+        .value_counts()
+    )
+
+    # Initialize counts with zero for all classes to ensure zero-count classes are included
+    # and to have class names in the order of class idx
+    class_counts = {name: 0 for name in task.class_names or []}
+    class_counts.update(dict(class_counts_df.iter_rows()))
+
+    return class_counts
+
+
+def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+    """
+    Get class counts for all tasks in the dataset.
+    The counts are returned as a dictionary where keys are in the format
+    '{primitive_name}/{task_name}/{class_name}' and values are their respective counts.
+
+    Example:
+        >>> counts = dataset.class_counts_all()
+        >>> print(counts)
+        {
+            objects/bboxes/car: 500
+            objects/bboxes/person: 0
+            classifications/weather/sunny: 300
+            classifications/weather/rainy: 0
+            ...
+        }
+    """
+    class_counts = {}
+    for task in dataset.info.tasks:
+        if task.class_names is None:
+            raise ValueError(f"Task '{task.name}' does not have class names defined.")
+        class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+        for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
+            count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
+            class_counts[count_name] = count
+
+    return class_counts
+
+
+def print_stats(dataset: HafniaDataset) -> None:
+    """
+    Prints verbose statistics about the dataset, including dataset name, version,
+    number of samples, and detailed counts of samples and tasks.
+    """
+    table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+    table_base.add_column("Property", style="cyan")
+    table_base.add_column("Value")
+    table_base.add_row("Dataset Name", dataset.info.dataset_name)
+    table_base.add_row("Version", dataset.info.version)
+    table_base.add_row("Number of samples", str(len(dataset.samples)))
+    rprint(table_base)
+    print_sample_and_task_counts(dataset)
+    print_class_distribution(dataset)
+
+
+def print_class_distribution(dataset: HafniaDataset) -> None:
+    """
+    Prints the class distribution for each task in the dataset.
+    """
+    for task in dataset.info.tasks:
+        if task.class_names is None:
+            raise ValueError(f"Task '{task.name}' does not have class names defined.")
+        class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+        # Print class distribution
+        rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+        rich_table.add_column("Class Name", style="cyan")
+        rich_table.add_column("Class Idx", style="cyan")
+        rich_table.add_column("Count", justify="right")
+        for class_name, count in class_counts.items():
+            class_idx = task.class_names.index(class_name)  # Get class idx from task info
+            rich_table.add_row(class_name, str(class_idx), str(count))
+        rprint(rich_table)
+
+
+def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
+    """
+    Prints a table with sample counts and task counts for each primitive type
+    in total and for each split (train, val, test).
+    """
+    from hafnia.dataset.operations.table_transformations import create_primitive_table
+    from hafnia.dataset.primitives import PRIMITIVE_TYPES
+
+    splits_sets = {
+        "All": SplitName.valid_splits(),
+        "Train": [SplitName.TRAIN],
+        "Validation": [SplitName.VAL],
+        "Test": [SplitName.TEST],
+    }
+    rows = []
+    for split_name, splits in splits_sets.items():
+        dataset_split = dataset.create_split_dataset(splits)
+        table = dataset_split.samples
+        row = {}
+        row["Split"] = split_name
+        row["Sample "] = str(len(table))
+        for PrimitiveType in PRIMITIVE_TYPES:
+            column_name = PrimitiveType.column_name()
+            objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+            if objects_df is None:
+                continue
+            for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                count = len(object_group[FieldName.CLASS_NAME])
+                row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+        rows.append(row)
+
+    rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+    for i_row, row in enumerate(rows):
+        if i_row == 0:
+            for column_name in row.keys():
+                rich_table.add_column(column_name, justify="left", style="cyan")
+        rich_table.add_row(*[str(value) for value in row.values()])
+    rprint(rich_table)
+
+
+def check_dataset(dataset: HafniaDataset):
+    """
+    Performs various checks on the dataset to ensure its integrity and consistency.
+    Raises errors if any issues are found.
+    """
+    from hafnia.dataset.hafnia_dataset import Sample
+
+    user_logger.info("Checking Hafnia dataset...")
+    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
+    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+    sample_dataset = dataset.create_sample_dataset()
+    if len(sample_dataset) == 0:
+        raise ValueError("The dataset does not include a sample dataset")
+
+    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+    expected_splits = SplitName.valid_splits()
+    if set(actual_splits) != set(expected_splits):
+        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+    dataset.check_dataset_tasks()
+
+    expected_tasks = dataset.info.tasks
+    distribution = dataset.info.distributions or []
+    distribution_names = [task.name for task in distribution]
+    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+    for PrimitiveType in PRIMITIVE_TYPES:
+        column_name = PrimitiveType.column_name()
+        if column_name not in dataset.samples.columns:
+            continue
+        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+        if objects_df is None:
+            continue
+        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+            if has_task or (task_name in distribution_names):
+                continue
+            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+            raise ValueError(
+                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                f"classes: {class_names}. "
+            )
+
+    for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+        sample = Sample(**sample_dict)  # noqa: F841
+
+
+def check_dataset_tasks(dataset: HafniaDataset):
+    """
+    Checks that the tasks defined in 'dataset.info.tasks' are consistent with the data in 'dataset.samples'.
+    """
+    dataset.info.check_for_duplicate_task_names()
+
+    for task in dataset.info.tasks:
+        primitive = task.primitive.__name__
+        column_name = task.primitive.column_name()
+        primitive_column = dataset.samples[column_name]
+        msg_something_wrong = (
+            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+            f"For '{primitive=}' and '{task.name=}' "
+        )
+        if primitive_column.dtype == pl.Null:
+            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+        if len(dataset) > 0:  # Check only performed for non-empty datasets
+            primitive_table = (
+                primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+            )
+            if primitive_table.is_empty():
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+                )
+
+            actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+            if task.class_names is None:
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+                )
+            defined_classes = set(task.class_names)
+
+            if not actual_classes.issubset(defined_classes):
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
+                    f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+                )
+            # Check class_indices
+            mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+                lambda x: task.class_names.index(x), return_dtype=pl.Int64
+            )
+            table_indices = primitive_table[FieldName.CLASS_IDX]
+
+            error_msg = msg_something_wrong + (
+                f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+            )
+            assert mapped_indices.equals(table_indices), error_msg
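The statistics helpers above are module-level functions that take a `HafniaDataset` as their first argument; the docstrings also show them being called as methods on the dataset (e.g. `dataset.class_counts_all()`), which the `hafnia_dataset.py` changes in this release appear to wire up. A minimal, illustrative sketch of the module-level API, assuming `dataset` is an already loaded `HafniaDataset` (loading is outside this file) and using a hypothetical task name:

```python
# Illustrative sketch only: `dataset` is assumed to be an already loaded HafniaDataset;
# the task name "weather" is hypothetical and should match a task defined in dataset.info.tasks.
from hafnia.dataset.operations import dataset_stats

print(dataset_stats.split_counts(dataset))                     # e.g. {"train": 8000, "val": 1000, "test": 1000}
print(dataset_stats.class_counts_for_task(dataset, task_name="weather"))
print(dataset_stats.class_counts_all(dataset))                 # keys formatted as "{Primitive}/{task_name}/{class_name}"

dataset_stats.print_stats(dataset)    # rich tables: overview, per-split sample/task counts, class distributions
dataset_stats.check_dataset(dataset)  # raises if splits, tasks, class names or class indices are inconsistent
```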
hafnia/dataset/operations/dataset_transformations.py

@@ -29,19 +29,27 @@ HafniaDataset class and a RecipeTransform class in the `data_recipe/recipe_trans
 that the signatures match.
 """
 
+import json
+import re
+import textwrap
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import cv2
+import more_itertools
 import numpy as np
 import polars as pl
 from PIL import Image
 from tqdm import tqdm
 
 from hafnia.dataset import dataset_helpers
+from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS, FieldName
+from hafnia.dataset.primitives import get_primitive_type_from_string
+from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.utils import remove_duplicates_preserve_order
 
-if TYPE_CHECKING:
-    from hafnia.dataset.hafnia_dataset import HafniaDataset
+if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
 
 
 ### Image transformations ###
@@ -79,4 +87,317 @@ def transform_images(
         new_paths.append(str(new_path))
 
     table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
-    return dataset.
+    return dataset.update_samples(table)
+
+
+def get_task_info_from_task_name_and_primitive(
+    tasks: List["TaskInfo"],
+    task_name: Optional[str] = None,
+    primitive: Union[None, str, Type[Primitive]] = None,
+) -> "TaskInfo":
+    if len(tasks) == 0:
+        raise ValueError("Dataset has no tasks defined.")
+
+    tasks_str = "\n".join([f"\t{task.__repr__()}" for task in tasks])
+    if task_name is None and primitive is None:
+        if len(tasks) == 1:
+            return tasks[0]
+        else:
+            raise ValueError(
+                "For multiple tasks, you will need to specify 'task_name' or 'type_primitive' "
+                "to return a unique task. The dataset contains the following tasks: \n" + tasks_str
+            )
+
+    if isinstance(primitive, str):
+        primitive = get_primitive_type_from_string(primitive)
+
+    tasks_filtered = tasks
+    if primitive is None:
+        tasks_filtered = [task for task in tasks if task.name == task_name]
+
+        if len(tasks_filtered) == 0:
+            raise ValueError(f"No task found with {task_name=}. Available tasks: \n {tasks_str}")
+
+        unique_primitives = set(task.primitive for task in tasks_filtered)
+        if len(unique_primitives) > 1:
+            raise ValueError(
+                f"Found multiple tasks with {task_name=} using different primitives {unique_primitives=}. "
+                "Please specify the primitive type to make it unique. "
+                f"The dataset contains the following tasks: \n {tasks_str}"
+            )
+        primitive = list(unique_primitives)[0]
+
+    if task_name is None:
+        tasks_filtered = [task for task in tasks if task.primitive == primitive]
+        if len(tasks_filtered) == 0:
+            raise ValueError(f"No task found with {primitive=}. Available tasks: \n {tasks_str}")
+
+        unique_task_names = set(task.name for task in tasks_filtered)
+        if len(unique_task_names) > 1:
+            raise ValueError(
+                f"Found multiple tasks with {primitive=} using different task names {unique_task_names=}. "
+                "Please specify the 'task_name' to make it unique."
+                f"The dataset contains the following tasks: \n {tasks_str}"
+            )
+        task_name = list(unique_task_names)[0]
+
+    tasks_filtered = [task for task in tasks_filtered if task.primitive == primitive and task.name == task_name]
+    if len(tasks_filtered) == 0:
+        raise ValueError(f"No task found with {task_name=} and {primitive=}. Available tasks: \n {tasks_str}")
+
+    if len(tasks_filtered) > 1:
+        raise ValueError(
+            f"Multiple tasks found with {task_name=} and {primitive=}. "
+            f"This should never happen. The dataset contains the following tasks: \n {tasks_str}"
+        )
+    task = tasks_filtered[0]
+    return task
+
+
+def class_mapper(
+    dataset: "HafniaDataset",
+    class_mapping: Dict[str, str],
+    method: str = "strict",
+    primitive: Optional[Type[Primitive]] = None,
+    task_name: Optional[str] = None,
+) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+    allowed_methods = ("strict", "remove_undefined", "keep_undefined")
+    if method not in allowed_methods:
+        raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
+
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+    current_names = task.class_names or []
+
+    # Expand wildcard mappings
+    class_mapping = expand_class_mapping(class_mapping, current_names)
+
+    non_existing_mapping_names = set(class_mapping) - set(current_names)
+    if len(non_existing_mapping_names) > 0:
+        raise ValueError(
+            f"The specified class mapping contains class names {list(non_existing_mapping_names)} "
+            f"that do not exist in the dataset task '{task.name}'. "
+            f"Available class names: {current_names}"
+        )
+
+    missing_class_names = [c for c in current_names if c not in class_mapping]  # List-comprehension to preserve order
+    class_mapping = class_mapping.copy()
+    if method == "strict":
+        pass  # Continue to strict mapping below
+    elif method == "remove_undefined":
+        for missing_class_name in missing_class_names:
+            class_mapping[missing_class_name] = OPS_REMOVE_CLASS
+    elif method == "keep_undefined":
+        for missing_class_name in missing_class_names:
+            class_mapping[missing_class_name] = missing_class_name
+    else:
+        raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
+
+    missing_class_names = [c for c in current_names if c not in class_mapping]
+    if len(missing_class_names) > 0:
+        error_msg = f"""\
+        The specified class mapping is not a strict mapping - meaning that all class names have not
+        been mapped to a new class name.
+        In the current mapping, the following classes {list(missing_class_names)} have not been mapped.
+        The currently specified mapping is:
+        {json.dumps(class_mapping, indent=2)}
+        A strict mapping will replace all old class names (dictionary keys) to new class names (dictionary values).
+        Please update the mapping to include all class names from the dataset task '{task.name}'.
+        To keep class map to the same name e.g. 'person' = 'person'
+        or remove class by using the '__REMOVE__' key, e.g. 'person': '__REMOVE__'."""
+        raise ValueError(textwrap.dedent(error_msg))
+
+    new_class_names = remove_duplicates_preserve_order(class_mapping.values())
+
+    if OPS_REMOVE_CLASS in new_class_names:
+        # Move __REMOVE__ to the end of the list if it exists
+        new_class_names.append(new_class_names.pop(new_class_names.index(OPS_REMOVE_CLASS)))
+    name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}
+
+    samples = dataset.samples
+    samples_updated = samples.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.eval(
+            pl.element().struct.with_fields(
+                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
+                .then(pl.field(FieldName.CLASS_NAME).replace_strict(class_mapping))
+                .otherwise(pl.field(FieldName.CLASS_NAME))
+                .alias(FieldName.CLASS_NAME)
+            )
+        )
+        .alias(task.primitive.column_name())
+    )
+
+    # Update class indices too
+    samples_updated = samples_updated.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.eval(
+            pl.element().struct.with_fields(
+                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
+                .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
+                .otherwise(pl.field(FieldName.CLASS_IDX))
+                .alias(FieldName.CLASS_IDX)
+            )
+        )
+        .alias(task.primitive.column_name())
+    )
+
+    if OPS_REMOVE_CLASS in new_class_names:  # Remove class_names that are mapped to REMOVE_CLASS
+        samples_updated = samples_updated.with_columns(
+            pl.col(task.primitive.column_name())
+            .list.filter(pl.element().struct.field(FieldName.CLASS_NAME) != OPS_REMOVE_CLASS)
+            .alias(task.primitive.column_name())
+        )
+
+        new_class_names = [c for c in new_class_names if c != OPS_REMOVE_CLASS]
+
+    new_task = task.model_copy(deep=True)
+    new_task.class_names = new_class_names
+    dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+    return HafniaDataset(info=dataset_info, samples=samples_updated)
+
+
+def expand_class_mapping(wildcard_mapping: Dict[str, str], class_names: List[str]) -> Dict[str, str]:
+    """
+    Expand a wildcard class mapping to a full explicit mapping.
+
+    This function takes a mapping that may contain wildcard patterns (using '*')
+    and expands them to match actual class names from a dataset. Exact matches
+    take precedence over wildcard patterns.
+
+    Examples:
+        >>> from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS
+        >>> wildcard_mapping = {
+        ...     "Person": "Person",
+        ...     "Vehicle.*": "Vehicle",
+        ...     "Vehicle.Trailer": OPS_REMOVE_CLASS
+        ... }
+        >>> class_names = [
+        ...     "Person", "Vehicle.Car", "Vehicle.Trailer", "Vehicle.Bus", "Animal.Dog"
+        ... ]
+        >>> result = expand_wildcard_mapping(wildcard_mapping, class_names)
+        >>> print(result)
+        {
+            "Person": "Person",
+            "Vehicle.Car": "Vehicle",
+            "Vehicle.Trailer": OPS_REMOVE_CLASS,  # Exact match overrides wildcard
+            "Vehicle.Bus": "Vehicle",
+            # Note: "Animal.Dog" is not included as it doesn't match any pattern
+        }
+    """
+    expanded_mapping = {}
+    for match_pattern, mapping_value in wildcard_mapping.items():
+        if "*" in match_pattern:
+            # Convert wildcard pattern to regex: Escape special regex characters except *, then replace * with .*
+            regex_pattern = re.escape(match_pattern).replace("\\*", ".*")
+            class_names_matched = [cn for cn in class_names if re.fullmatch(regex_pattern, cn)]
+            expanded_mapping.update({cn: mapping_value for cn in class_names_matched})
+        else:
+            expanded_mapping.pop(match_pattern, None)
+            expanded_mapping[match_pattern] = mapping_value
+    return expanded_mapping
+
+
+def rename_task(
+    dataset: "HafniaDataset",
+    old_task_name: str,
+    new_task_name: str,
+) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+    old_task = dataset.info.get_task_by_name(task_name=old_task_name)
+    new_task = old_task.model_copy(deep=True)
+    new_task.name = new_task_name
+    samples = dataset.samples.with_columns(
+        pl.col(old_task.primitive.column_name())
+        .list.eval(
+            pl.element().struct.with_fields(
+                pl.field(FieldName.TASK_NAME).replace(old_task.name, new_task.name).alias(FieldName.TASK_NAME)
+            )
+        )
+        .alias(new_task.primitive.column_name())
+    )
+
+    dataset_info = dataset.info.replace_task(old_task=old_task, new_task=new_task)
+    return HafniaDataset(info=dataset_info, samples=samples)
+
+
+def select_samples_by_class_name(
+    dataset: "HafniaDataset",
+    name: Union[List[str], str],
+    task_name: Optional[str] = None,
+    primitive: Optional[Type[Primitive]] = None,
+) -> "HafniaDataset":
+    task, class_names = _validate_inputs_select_samples_by_class_name(
+        dataset=dataset,
+        name=name,
+        task_name=task_name,
+        primitive=primitive,
+    )
+
+    samples = dataset.samples.filter(
+        pl.col(task.primitive.column_name())
+        .list.eval(
+            pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names)
+            & (pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+        )
+        .list.any()
+    )
+
+    dataset_updated = dataset.update_samples(samples)
+    return dataset_updated
+
+
+def _validate_inputs_select_samples_by_class_name(
+    dataset: "HafniaDataset",
+    name: Union[List[str], str],
+    task_name: Optional[str] = None,
+    primitive: Optional[Type[Primitive]] = None,
+) -> Tuple["TaskInfo", Set[str]]:
+    if isinstance(name, str):
+        name = [name]
+    names = set(name)
+
+    # Check that specified names are available in at least one of the tasks
+    available_names_across_tasks = set(more_itertools.flatten([t.class_names for t in dataset.info.tasks]))
+    missing_class_names_across_tasks = names - available_names_across_tasks
+    if len(missing_class_names_across_tasks) > 0:
+        raise ValueError(
+            f"The specified names {list(names)} have not been found in any of the tasks. "
+            f"Available class names: {available_names_across_tasks}"
+        )
+
+    # Auto infer task if task_name and primitive are not provided
+    if task_name is None and primitive is None:
+        tasks_with_names = [t for t in dataset.info.tasks if names.issubset(t.class_names or [])]
+        if len(tasks_with_names) == 0:
+            raise ValueError(
+                f"The specified names {list(names)} have not been found in any of the tasks. "
+                f"Available class names: {available_names_across_tasks}"
+            )
+        if len(tasks_with_names) > 1:
+            raise ValueError(
+                f"Found multiple tasks containing the specified names {list(names)}. "
+                f"Specify either 'task_name' or 'primitive' to only select from one task. "
+                f"Tasks containing all provided names: {[t.name for t in tasks_with_names]}"
+            )
+
+        task = tasks_with_names[0]
+
+    else:
+        task = get_task_info_from_task_name_and_primitive(
+            tasks=dataset.info.tasks,
+            task_name=task_name,
+            primitive=primitive,
+        )
+
+    task_class_names = set(task.class_names or [])
+    missing_class_names = names - task_class_names
+    if len(missing_class_names) > 0:
+        raise ValueError(
+            f"The specified names {list(missing_class_names)} have not been found for the '{task.name}' task. "
+            f"Available class names: {task_class_names}"
+        )
+
+    return task, names
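The transformations above are pure functions that return a new `HafniaDataset` rather than mutating the input, and `class_mapper` accepts wildcard patterns that `expand_class_mapping` resolves against the task's class names. A minimal, illustrative sketch, assuming `dataset` is an already loaded `HafniaDataset`; the task and class names below are hypothetical:

```python
# Illustrative sketch only: `dataset` is an assumed, already loaded HafniaDataset;
# the task name "objects" and the class names below are hypothetical.
from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS
from hafnia.dataset.operations.dataset_transformations import (
    class_mapper,
    expand_class_mapping,
    rename_task,
    select_samples_by_class_name,
)

# Wildcard expansion is a pure helper and can be previewed without a dataset.
preview = expand_class_mapping(
    {"Person": "Person", "Vehicle.*": "Vehicle", "Vehicle.Trailer": OPS_REMOVE_CLASS},
    ["Person", "Vehicle.Car", "Vehicle.Trailer", "Vehicle.Bus"],
)
# {"Person": "Person", "Vehicle.Car": "Vehicle", "Vehicle.Bus": "Vehicle", "Vehicle.Trailer": OPS_REMOVE_CLASS}

# Remap classes for one task; with method="remove_undefined", classes not covered by the mapping are dropped.
remapped = class_mapper(
    dataset,
    class_mapping={"Person": "Person", "Vehicle.*": "Vehicle", "Vehicle.Trailer": OPS_REMOVE_CLASS},
    method="remove_undefined",
    task_name="objects",
)

# Rename the task and keep only samples that contain at least one of the selected classes.
renamed = rename_task(remapped, old_task_name="objects", new_task_name="vehicles")
subset = select_samples_by_class_name(renamed, name="Vehicle", task_name="vehicles")
```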