hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +3 -1
- cli/config.py +43 -3
- cli/keychain.py +88 -0
- cli/profile_cmds.py +5 -2
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_helpers.py +9 -2
- hafnia/dataset/dataset_names.py +130 -16
- hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
- hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
- hafnia/dataset/dataset_upload_helper.py +83 -22
- hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
- hafnia/dataset/hafnia_dataset.py +396 -96
- hafnia/dataset/operations/dataset_stats.py +84 -73
- hafnia/dataset/operations/dataset_transformations.py +116 -47
- hafnia/dataset/operations/table_transformations.py +135 -17
- hafnia/dataset/primitives/bbox.py +25 -14
- hafnia/dataset/primitives/bitmask.py +22 -15
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +15 -10
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +12 -9
- hafnia/experiment/hafnia_logger.py +0 -9
- hafnia/platform/dataset_recipe.py +7 -2
- hafnia/platform/datasets.py +5 -9
- hafnia/platform/download.py +24 -90
- hafnia/torch_helpers.py +12 -12
- hafnia/utils.py +17 -0
- hafnia/visualizations/image_visualizations.py +3 -1
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
- hafnia-0.4.1.dist-info/RECORD +57 -0
- hafnia-0.3.0.dist-info/RECORD +0 -53
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
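The detailed diffs below cover `hafnia/dataset/operations/dataset_stats.py` and `hafnia/dataset/operations/dataset_transformations.py`. As a quick orientation, here is a minimal usage sketch (not part of the package) of the renamed statistics helpers: the function names and signatures are taken from the diff below, while the `summarize` wrapper, the import style, and the already-loaded `dataset` object are assumptions for illustration.

```python
# Hypothetical usage sketch for the 0.4.1 stats helpers; only the function
# names and signatures below come from the diff, everything else is assumed.
from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.dataset.operations import dataset_stats


def summarize(dataset: HafniaDataset) -> None:
    # Mapping of split name to number of samples
    print(dataset_stats.calculate_split_counts(dataset))

    # One dict per (primitive, task, class) with a "Count" entry
    for row in dataset_stats.calculate_class_counts(dataset):
        print(row)

    # One row per split ("All"/"Train"/"Validation"/"Test"), including per-primitive counts
    for row in dataset_stats.calculate_split_counts_extended(dataset):
        print(row)

    # New in 0.4.1: split/sample-dataset checks can be skipped via check_splits
    dataset_stats.check_dataset(dataset, check_splits=False)
```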
hafnia/dataset/operations/dataset_stats.py

@@ -1,14 +1,14 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 import polars as pl
 import rich
 from rich import print as rprint
+from rich.progress import track
 from rich.table import Table
-from tqdm import tqdm

-from hafnia.dataset.dataset_names import
+from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
 from hafnia.dataset.operations.table_transformations import create_primitive_table
 from hafnia.dataset.primitives import PRIMITIVE_TYPES
 from hafnia.log import user_logger
@@ -18,14 +18,14 @@ if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type
     from hafnia.dataset.primitives.primitive import Primitive


-def
+def calculate_split_counts(dataset: HafniaDataset) -> Dict[str, int]:
     """
     Returns a dictionary with the counts of samples in each split of the dataset.
     """
-    return dict(dataset.samples[
+    return dict(dataset.samples[SampleField.SPLIT].value_counts().iter_rows())


-def
+def calculate_task_class_counts(
     dataset: HafniaDataset,
     primitive: Optional[Type[Primitive]] = None,
     task_name: Optional[str] = None,
@@ -53,7 +53,7 @@ def class_counts_for_task(
         dataset.samples[task.primitive.column_name()]
         .explode()
         .struct.unnest()
-        .filter(pl.col(
+        .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)[PrimitiveField.CLASS_NAME]
         .value_counts()
     )

@@ -65,7 +65,7 @@ def class_counts_for_task(
     return class_counts


-def
+def calculate_class_counts(dataset: HafniaDataset) -> List[Dict[str, Any]]:
     """
     Get class counts for all tasks in the dataset.
     The counts are returned as a dictionary where keys are in the format
@@ -74,25 +74,59 @@ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
     Example:
         >>> counts = dataset.class_counts_all()
         >>> print(counts)
-
-
-
-
-
-
-        }
+        [
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'car', 'Count': 500},
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'bus', 'Count': 100},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'indoor', 'Count': 300},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'outdoor', 'Count': 700},
+        ]
     """
-
+    count_info = []
     for task in dataset.info.tasks:
-
-
-
+        class_name_counts = dataset.calculate_task_class_counts(task_name=task.name)
+        for name, counts in class_name_counts.items():
+            count_info.append(
+                {
+                    "Primitive": task.primitive.__name__,
+                    "Task Name": task.name,
+                    "Class Name": name,
+                    "Count": counts,
+                }
+            )
+    return count_info

-        for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
-            count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
-            class_counts[count_name] = count

-
+def calculate_primitive_counts(dataset: HafniaDataset) -> Dict[str, int]:
+    annotation_counts = {}
+    for task in dataset.info.tasks:
+        objects = dataset.create_primitive_table(task.primitive, task_name=task.name)
+        name = task.primitive.__name__
+        if task.name != task.primitive.default_task_name():
+            name = f"{name}.{task.name}"
+        annotation_counts[name] = len(objects)
+    return annotation_counts
+
+
+def calculate_split_counts_extended(dataset: HafniaDataset) -> List[Dict[str, Any]]:
+    splits_sets = {
+        "All": SplitName.valid_splits(),
+        "Train": [SplitName.TRAIN],
+        "Validation": [SplitName.VAL],
+        "Test": [SplitName.TEST],
+    }
+    rows = []
+    for split_name, splits in splits_sets.items():
+        dataset_split = dataset.create_split_dataset(splits)
+        table = dataset_split.samples
+        row: Dict[str, Any] = {}
+        row["Split"] = split_name
+        row["Samples "] = str(len(table))
+
+        primitive_counts = calculate_primitive_counts(dataset_split)
+        row.update(primitive_counts)
+        rows.append(row)
+
+    return rows


 def print_stats(dataset: HafniaDataset) -> None:
@@ -118,10 +152,13 @@ def print_class_distribution(dataset: HafniaDataset) -> None:
     for task in dataset.info.tasks:
         if task.class_names is None:
             raise ValueError(f"Task '{task.name}' does not have class names defined.")
-        class_counts = dataset.
+        class_counts = dataset.calculate_task_class_counts(primitive=task.primitive, task_name=task.name)

         # Print class distribution
-        rich_table = Table(
+        rich_table = Table(
+            title=f"Class Count for '{task.primitive.__name__}/{task.name}'",
+            show_lines=False,
+        )
         rich_table.add_column("Class Name", style="cyan")
         rich_table.add_column("Class Idx", style="cyan")
         rich_table.add_column("Count", justify="right")
@@ -136,32 +173,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     Prints a table with sample counts and task counts for each primitive type
     in total and for each split (train, val, test).
     """
-
-    from hafnia.dataset.primitives import PRIMITIVE_TYPES
-
-    splits_sets = {
-        "All": SplitName.valid_splits(),
-        "Train": [SplitName.TRAIN],
-        "Validation": [SplitName.VAL],
-        "Test": [SplitName.TEST],
-    }
-    rows = []
-    for split_name, splits in splits_sets.items():
-        dataset_split = dataset.create_split_dataset(splits)
-        table = dataset_split.samples
-        row = {}
-        row["Split"] = split_name
-        row["Sample "] = str(len(table))
-        for PrimitiveType in PRIMITIVE_TYPES:
-            column_name = PrimitiveType.column_name()
-            objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-            if objects_df is None:
-                continue
-            for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                count = len(object_group[FieldName.CLASS_NAME])
-                row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-        rows.append(row)
-
+    rows = calculate_split_counts_extended(dataset)
     rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
     for i_row, row in enumerate(rows):
         if i_row == 0:
@@ -171,7 +183,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     rprint(rich_table)


-def check_dataset(dataset: HafniaDataset):
+def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
     """
     Performs various checks on the dataset to ensure its integrity and consistency.
     Raises errors if any issues are found.
@@ -179,24 +191,23 @@ def check_dataset(dataset: HafniaDataset):
     from hafnia.dataset.hafnia_dataset import Sample

     user_logger.info("Checking Hafnia dataset...")
-    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

-
-
-
+    if check_splits:
+        sample_dataset = dataset.create_sample_dataset()
+        if len(sample_dataset) == 0:
+            raise ValueError("The dataset does not include a sample dataset")
+
+    actual_splits = dataset.samples.select(pl.col(SampleField.SPLIT)).unique().to_series().to_list()
+    required_splits = SplitName.valid_splits()

-
-
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+    if not set(required_splits).issubset(set(actual_splits)):
+        raise ValueError(f"Expected all splits '{required_splits}' in dataset, but got '{actual_splits}'. ")

     dataset.check_dataset_tasks()

     expected_tasks = dataset.info.tasks
-
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+    # Check that tasks found in the 'dataset.samples' matches the tasks defined in 'dataset.info.tasks'
     for PrimitiveType in PRIMITIVE_TYPES:
         column_name = PrimitiveType.column_name()
         if column_name not in dataset.samples.columns:
@@ -204,18 +215,18 @@ def check_dataset(dataset: HafniaDataset):
         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
         if objects_df is None:
             continue
-        for (task_name,), object_group in objects_df.group_by(
+        for (task_name,), object_group in objects_df.group_by(PrimitiveField.TASK_NAME):
             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task
+            if has_task:
                 continue
-            class_names = object_group[
+            class_names = object_group[PrimitiveField.CLASS_NAME].unique().to_list()
             raise ValueError(
                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{
+                f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. Missing task has the following "
                 f"classes: {class_names}. "
             )

-    for sample_dict in
+    for sample_dict in track(dataset, description="Checking samples in dataset"):
         sample = Sample(**sample_dict)  # noqa: F841


@@ -238,7 +249,7 @@ def check_dataset_tasks(dataset: HafniaDataset):

         if len(dataset) > 0:  # Check only performed for non-empty datasets
             primitive_table = (
-                primitive_column.explode().struct.unnest().filter(pl.col(
+                primitive_column.explode().struct.unnest().filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
             )
             if primitive_table.is_empty():
                 raise ValueError(
@@ -246,7 +257,7 @@ def check_dataset_tasks(dataset: HafniaDataset):
                     + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
                 )

-            actual_classes = set(primitive_table[
+            actual_classes = set(primitive_table[PrimitiveField.CLASS_NAME].unique().to_list())
             if task.class_names is None:
                 raise ValueError(
                     msg_something_wrong
@@ -261,12 +272,12 @@ def check_dataset_tasks(dataset: HafniaDataset):
                     f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
                 )
             # Check class_indices
-            mapped_indices = primitive_table[
+            mapped_indices = primitive_table[PrimitiveField.CLASS_NAME].map_elements(
                 lambda x: task.class_names.index(x), return_dtype=pl.Int64
             )
-            table_indices = primitive_table[
+            table_indices = primitive_table[PrimitiveField.CLASS_IDX]

             error_msg = msg_something_wrong + (
-                f"class indices in '{
+                f"class indices in '{PrimitiveField.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
             )
             assert mapped_indices.equals(table_indices), error_msg
hafnia/dataset/operations/dataset_transformations.py

@@ -31,25 +31,32 @@ that the signatures match.

 import json
 import re
+import shutil
 import textwrap
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional,
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union

 import cv2
 import more_itertools
 import numpy as np
 import polars as pl
-from
-from tqdm import tqdm
+from rich.progress import track

 from hafnia.dataset import dataset_helpers
-from hafnia.dataset.dataset_names import
+from hafnia.dataset.dataset_names import (
+    OPS_REMOVE_CLASS,
+    PrimitiveField,
+    SampleField,
+    StorageFormat,
+)
+from hafnia.dataset.operations.table_transformations import update_class_indices
 from hafnia.dataset.primitives import get_primitive_type_from_string
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
 from hafnia.utils import remove_duplicates_preserve_order

 if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
-    from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample, TaskInfo


 ### Image transformations ###
@@ -57,7 +64,7 @@ class AnonymizeByPixelation:
     def __init__(self, resize_factor: float = 0.10):
         self.resize_factor = resize_factor

-    def __call__(self, frame: np.ndarray) -> np.ndarray:
+    def __call__(self, frame: np.ndarray, sample: "Sample") -> np.ndarray:
         org_size = frame.shape[:2]
         frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
         frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
@@ -66,30 +73,100 @@ class AnonymizeByPixelation:

 def transform_images(
     dataset: "HafniaDataset",
-    transform: Callable[[np.ndarray], np.ndarray],
+    transform: Callable[[np.ndarray, "Sample"], np.ndarray],
     path_output: Path,
+    description: str = "Transform images",
 ) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import Sample
+
     new_paths = []
     path_image_folder = path_output / "data"
     path_image_folder.mkdir(parents=True, exist_ok=True)

-    for
-
-
-
-
-        image = np.array(Image.open(org_path))
-        image_transformed = transform(image)
+    for sample_dict in track(dataset, description=description):
+        sample = Sample(**sample_dict)
+        image = sample.read_image()
+        image_transformed = transform(image, sample)
         new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)

         if not new_path.exists():
             raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
         new_paths.append(str(new_path))

-    table = dataset.samples.with_columns(pl.Series(new_paths).alias(
+    table = dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
     return dataset.update_samples(table)


+def convert_to_image_storage_format(
+    dataset: "HafniaDataset",
+    path_output_folder: Path,
+    reextract_frames: bool,
+    image_format: str = "png",
+    transform: Optional[Callable[[np.ndarray, "Sample"], np.ndarray]] = None,
+) -> "HafniaDataset":
+    """
+    Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
+    to an image-based dataset by extracting frames.
+    """
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+    path_images = path_output_folder / "data"
+    path_images.mkdir(parents=True, exist_ok=True)
+
+    # Only video format dataset samples are processed
+    video_based_samples = dataset.samples.filter(pl.col(SampleField.STORAGE_FORMAT) == StorageFormat.VIDEO)
+
+    if video_based_samples.is_empty():
+        user_logger.info("Dataset has no video-based samples. Returning dataset unchanged.")
+        return dataset
+
+    update_list = []
+    for (path_video,), video_samples in video_based_samples.group_by(SampleField.FILE_PATH):
+        assert Path(path_video).exists(), (
+            f"'{path_video}' not found. We expect the video to be downloaded to '{path_output_folder}'"
+        )
+        video = cv2.VideoCapture(str(path_video))
+
+        video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
+        for sample_dict in track(
+            video_samples.iter_rows(named=True),
+            total=video_samples.height,
+            description=f"Extracting frames from '{Path(path_video).name}'",
+        ):
+            frame_number = sample_dict[SampleField.COLLECTION_INDEX]
+            image_name = f"{Path(path_video).stem}_F{frame_number:06d}.{image_format}"
+            path_image = path_images / image_name
+
+            update_list.append(
+                {
+                    SampleField.SAMPLE_INDEX: sample_dict[SampleField.SAMPLE_INDEX],
+                    SampleField.COLLECTION_ID: sample_dict[SampleField.COLLECTION_ID],
+                    SampleField.COLLECTION_INDEX: frame_number,
+                    SampleField.FILE_PATH: path_image.as_posix(),
+                    SampleField.STORAGE_FORMAT: StorageFormat.IMAGE,
+                }
+            )
+            if reextract_frames:
+                shutil.rmtree(path_image, ignore_errors=True)
+            if path_image.exists():
+                continue
+
+            video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+            ret, frame_org = video.read()
+            if not ret:
+                raise RuntimeError(f"Could not read frame {frame_number} from video '{path_video}'")
+
+            if transform is not None:
+                frame_org = transform(frame_org, Sample(**sample_dict))
+
+            cv2.imwrite(str(path_image), frame_org)
+    df_updates = pl.DataFrame(update_list)
+    samples_as_images = dataset.samples.update(df_updates, on=[SampleField.COLLECTION_ID, SampleField.COLLECTION_INDEX])
+    hafnia_dataset = HafniaDataset(samples=samples_as_images, info=dataset.info)
+
+    return hafnia_dataset
+
+
 def get_task_info_from_task_name_and_primitive(
     tasks: List["TaskInfo"],
     task_name: Optional[str] = None,
@@ -156,13 +233,16 @@ def get_task_info_from_task_name_and_primitive(

 def class_mapper(
     dataset: "HafniaDataset",
-    class_mapping: Dict[str, str],
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
     method: str = "strict",
     primitive: Optional[Type[Primitive]] = None,
     task_name: Optional[str] = None,
 ) -> "HafniaDataset":
     from hafnia.dataset.hafnia_dataset import HafniaDataset

+    if isinstance(class_mapping, list):
+        class_mapping = dict(class_mapping)
+
     allowed_methods = ("strict", "remove_undefined", "keep_undefined")
     if method not in allowed_methods:
         raise ValueError(f"Method '{method}' is not recognized. Allowed methods are: {allowed_methods}")
@@ -170,7 +250,7 @@ def class_mapper(
     task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
     current_names = task.class_names or []

-    # Expand wildcard mappings
+    # Expand wildcard mappings e.g. {"Vehicle.*": "Vehicle"} to {"Vehicle.Car": "Vehicle", "Vehicle.Bus": "Vehicle"}
     class_mapping = expand_class_mapping(class_mapping, current_names)

     non_existing_mapping_names = set(class_mapping) - set(current_names)
@@ -213,31 +293,16 @@ def class_mapper(
     if OPS_REMOVE_CLASS in new_class_names:
         # Move __REMOVE__ to the end of the list if it exists
         new_class_names.append(new_class_names.pop(new_class_names.index(OPS_REMOVE_CLASS)))
-    name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}

     samples = dataset.samples
     samples_updated = samples.with_columns(
         pl.col(task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.when(pl.field(
-                .then(pl.field(
-                .otherwise(pl.field(
-                .alias(
-            )
-        )
-        .alias(task.primitive.column_name())
-    )
-
-    # Update class indices too
-    samples_updated = samples_updated.with_columns(
-        pl.col(task.primitive.column_name())
-        .list.eval(
-            pl.element().struct.with_fields(
-                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
-                .otherwise(pl.field(FieldName.CLASS_IDX))
-                .alias(FieldName.CLASS_IDX)
+                pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(class_mapping, default="Missing"))
+                .otherwise(pl.field(PrimitiveField.CLASS_NAME))
+                .alias(PrimitiveField.CLASS_NAME)
             )
         )
         .alias(task.primitive.column_name())
@@ -246,7 +311,7 @@ def class_mapper(
     if OPS_REMOVE_CLASS in new_class_names:  # Remove class_names that are mapped to REMOVE_CLASS
         samples_updated = samples_updated.with_columns(
             pl.col(task.primitive.column_name())
-            .list.filter(pl.element().struct.field(
+            .list.filter(pl.element().struct.field(PrimitiveField.CLASS_NAME) != OPS_REMOVE_CLASS)
             .alias(task.primitive.column_name())
         )

@@ -255,6 +320,10 @@ def class_mapper(
     new_task = task.model_copy(deep=True)
     new_task.class_names = new_class_names
     dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+
+    # Update class indices to match new class names
+    samples_updated = update_class_indices(samples_updated, new_task)
+
     return HafniaDataset(info=dataset_info, samples=samples_updated)


@@ -313,7 +382,7 @@ def rename_task(
         pl.col(old_task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.field(
+                pl.field(PrimitiveField.TASK_NAME).replace(old_task.name, new_task.name).alias(PrimitiveField.TASK_NAME)
             )
         )
         .alias(new_task.primitive.column_name())
@@ -339,8 +408,8 @@ def select_samples_by_class_name(
     samples = dataset.samples.filter(
         pl.col(task.primitive.column_name())
         .list.eval(
-            pl.element().struct.field(
-            & (pl.element().struct.field(
+            pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names)
+            & (pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
         )
         .list.any()
     )
@@ -354,14 +423,14 @@ def _validate_inputs_select_samples_by_class_name(
     name: Union[List[str], str],
     task_name: Optional[str] = None,
     primitive: Optional[Type[Primitive]] = None,
-) -> Tuple["TaskInfo",
+) -> Tuple["TaskInfo", List[str]]:
     if isinstance(name, str):
         name = [name]
-    names =
+    names = list(name)

     # Check that specified names are available in at least one of the tasks
     available_names_across_tasks = set(more_itertools.flatten([t.class_names for t in dataset.info.tasks]))
-    missing_class_names_across_tasks = names - available_names_across_tasks
+    missing_class_names_across_tasks = set(names) - available_names_across_tasks
     if len(missing_class_names_across_tasks) > 0:
         raise ValueError(
             f"The specified names {list(names)} have not been found in any of the tasks. "
@@ -370,15 +439,15 @@ def _validate_inputs_select_samples_by_class_name(

     # Auto infer task if task_name and primitive are not provided
     if task_name is None and primitive is None:
-        tasks_with_names = [t for t in dataset.info.tasks if names.issubset(t.class_names or [])]
+        tasks_with_names = [t for t in dataset.info.tasks if set(names).issubset(t.class_names or [])]
         if len(tasks_with_names) == 0:
             raise ValueError(
-                f"The specified names {
+                f"The specified names {names} have not been found in any of the tasks. "
                 f"Available class names: {available_names_across_tasks}"
            )
         if len(tasks_with_names) > 1:
             raise ValueError(
-                f"Found multiple tasks containing the specified names {
+                f"Found multiple tasks containing the specified names {names}. "
                 f"Specify either 'task_name' or 'primitive' to only select from one task. "
                 f"Tasks containing all provided names: {[t.name for t in tasks_with_names]}"
             )
@@ -393,7 +462,7 @@ def _validate_inputs_select_samples_by_class_name(
         )

     task_class_names = set(task.class_names or [])
-    missing_class_names = names - task_class_names
+    missing_class_names = set(names) - task_class_names
     if len(missing_class_names) > 0:
         raise ValueError(
             f"The specified names {list(missing_class_names)} have not been found for the '{task.name}' task. "