hafnia-0.4.0-py3-none-any.whl → hafnia-0.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_names.py +128 -15
- hafnia/dataset/dataset_upload_helper.py +30 -25
- hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
- hafnia/dataset/hafnia_dataset.py +246 -72
- hafnia/dataset/operations/dataset_stats.py +82 -70
- hafnia/dataset/operations/dataset_transformations.py +102 -37
- hafnia/dataset/operations/table_transformations.py +132 -15
- hafnia/dataset/primitives/bbox.py +3 -5
- hafnia/dataset/primitives/bitmask.py +2 -7
- hafnia/dataset/primitives/classification.py +3 -3
- hafnia/dataset/primitives/polygon.py +2 -4
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +2 -2
- hafnia/platform/datasets.py +3 -7
- hafnia/platform/download.py +1 -72
- hafnia/torch_helpers.py +12 -12
- hafnia/visualizations/image_visualizations.py +2 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +4 -4
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/RECORD +25 -24
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/hafnia/dataset/operations/dataset_stats.py
+++ b/hafnia/dataset/operations/dataset_stats.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 import polars as pl
 import rich
@@ -8,7 +8,7 @@ from rich import print as rprint
 from rich.progress import track
 from rich.table import Table

-from hafnia.dataset.dataset_names import
+from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
 from hafnia.dataset.operations.table_transformations import create_primitive_table
 from hafnia.dataset.primitives import PRIMITIVE_TYPES
 from hafnia.log import user_logger
@@ -18,14 +18,14 @@ if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type
     from hafnia.dataset.primitives.primitive import Primitive


-def
+def calculate_split_counts(dataset: HafniaDataset) -> Dict[str, int]:
     """
     Returns a dictionary with the counts of samples in each split of the dataset.
     """
-    return dict(dataset.samples[
+    return dict(dataset.samples[SampleField.SPLIT].value_counts().iter_rows())


-def
+def calculate_task_class_counts(
     dataset: HafniaDataset,
     primitive: Optional[Type[Primitive]] = None,
     task_name: Optional[str] = None,
@@ -53,7 +53,7 @@ def class_counts_for_task(
         dataset.samples[task.primitive.column_name()]
         .explode()
         .struct.unnest()
-        .filter(pl.col(
+        .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)[PrimitiveField.CLASS_NAME]
         .value_counts()
     )
@@ -65,7 +65,7 @@ def class_counts_for_task(
     return class_counts


-def
+def calculate_class_counts(dataset: HafniaDataset) -> List[Dict[str, Any]]:
     """
     Get class counts for all tasks in the dataset.
     The counts are returned as a dictionary where keys are in the format
@@ -74,25 +74,59 @@ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
     Example:
         >>> counts = dataset.class_counts_all()
         >>> print(counts)
-
-
-
-
-
-
-        }
+        [
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'car', 'Count': 500},
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'bus', 'Count': 100},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'indoor', 'Count': 300},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'outdoor', 'Count': 700},
+        ]
     """
-
+    count_info = []
     for task in dataset.info.tasks:
-
-
-
+        class_name_counts = dataset.calculate_task_class_counts(task_name=task.name)
+        for name, counts in class_name_counts.items():
+            count_info.append(
+                {
+                    "Primitive": task.primitive.__name__,
+                    "Task Name": task.name,
+                    "Class Name": name,
+                    "Count": counts,
+                }
+            )
+    return count_info

-        for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
-            count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
-            class_counts[count_name] = count

-
+def calculate_primitive_counts(dataset: HafniaDataset) -> Dict[str, int]:
+    annotation_counts = {}
+    for task in dataset.info.tasks:
+        objects = dataset.create_primitive_table(task.primitive, task_name=task.name)
+        name = task.primitive.__name__
+        if task.name != task.primitive.default_task_name():
+            name = f"{name}.{task.name}"
+        annotation_counts[name] = len(objects)
+    return annotation_counts
+
+
+def calculate_split_counts_extended(dataset: HafniaDataset) -> List[Dict[str, Any]]:
+    splits_sets = {
+        "All": SplitName.valid_splits(),
+        "Train": [SplitName.TRAIN],
+        "Validation": [SplitName.VAL],
+        "Test": [SplitName.TEST],
+    }
+    rows = []
+    for split_name, splits in splits_sets.items():
+        dataset_split = dataset.create_split_dataset(splits)
+        table = dataset_split.samples
+        row: Dict[str, Any] = {}
+        row["Split"] = split_name
+        row["Samples "] = str(len(table))
+
+        primitive_counts = calculate_primitive_counts(dataset_split)
+        row.update(primitive_counts)
+        rows.append(row)
+
+    return rows


 def print_stats(dataset: HafniaDataset) -> None:
@@ -118,10 +152,13 @@ def print_class_distribution(dataset: HafniaDataset) -> None:
     for task in dataset.info.tasks:
         if task.class_names is None:
             raise ValueError(f"Task '{task.name}' does not have class names defined.")
-        class_counts = dataset.
+        class_counts = dataset.calculate_task_class_counts(primitive=task.primitive, task_name=task.name)

         # Print class distribution
-        rich_table = Table(
+        rich_table = Table(
+            title=f"Class Count for '{task.primitive.__name__}/{task.name}'",
+            show_lines=False,
+        )
         rich_table.add_column("Class Name", style="cyan")
         rich_table.add_column("Class Idx", style="cyan")
         rich_table.add_column("Count", justify="right")
@@ -136,32 +173,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     Prints a table with sample counts and task counts for each primitive type
     in total and for each split (train, val, test).
     """
-
-    from hafnia.dataset.primitives import PRIMITIVE_TYPES
-
-    splits_sets = {
-        "All": SplitName.valid_splits(),
-        "Train": [SplitName.TRAIN],
-        "Validation": [SplitName.VAL],
-        "Test": [SplitName.TEST],
-    }
-    rows = []
-    for split_name, splits in splits_sets.items():
-        dataset_split = dataset.create_split_dataset(splits)
-        table = dataset_split.samples
-        row = {}
-        row["Split"] = split_name
-        row["Sample "] = str(len(table))
-        for PrimitiveType in PRIMITIVE_TYPES:
-            column_name = PrimitiveType.column_name()
-            objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-            if objects_df is None:
-                continue
-            for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                count = len(object_group[FieldName.CLASS_NAME])
-                row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-        rows.append(row)
-
+    rows = calculate_split_counts_extended(dataset)
     rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
     for i_row, row in enumerate(rows):
         if i_row == 0:
@@ -171,7 +183,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     rprint(rich_table)


-def check_dataset(dataset: HafniaDataset):
+def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
     """
     Performs various checks on the dataset to ensure its integrity and consistency.
     Raises errors if any issues are found.
@@ -181,21 +193,21 @@ def check_dataset(dataset: HafniaDataset):
     user_logger.info("Checking Hafnia dataset...")
     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

-
-
-
+    if check_splits:
+        sample_dataset = dataset.create_sample_dataset()
+        if len(sample_dataset) == 0:
+            raise ValueError("The dataset does not include a sample dataset")
+
+        actual_splits = dataset.samples.select(pl.col(SampleField.SPLIT)).unique().to_series().to_list()
+        required_splits = SplitName.valid_splits()

-
-
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+        if not set(required_splits).issubset(set(actual_splits)):
+            raise ValueError(f"Expected all splits '{required_splits}' in dataset, but got '{actual_splits}'. ")

     dataset.check_dataset_tasks()

     expected_tasks = dataset.info.tasks
-
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+    # Check that tasks found in the 'dataset.samples' matches the tasks defined in 'dataset.info.tasks'
     for PrimitiveType in PRIMITIVE_TYPES:
         column_name = PrimitiveType.column_name()
         if column_name not in dataset.samples.columns:
@@ -203,14 +215,14 @@ def check_dataset(dataset: HafniaDataset):
         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
         if objects_df is None:
             continue
-        for (task_name,), object_group in objects_df.group_by(
+        for (task_name,), object_group in objects_df.group_by(PrimitiveField.TASK_NAME):
             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task
+            if has_task:
                 continue
-            class_names = object_group[
+            class_names = object_group[PrimitiveField.CLASS_NAME].unique().to_list()
             raise ValueError(
                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{
+                f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. Missing task has the following "
                 f"classes: {class_names}. "
             )

@@ -237,7 +249,7 @@ def check_dataset_tasks(dataset: HafniaDataset):

         if len(dataset) > 0:  # Check only performed for non-empty datasets
             primitive_table = (
-                primitive_column.explode().struct.unnest().filter(pl.col(
+                primitive_column.explode().struct.unnest().filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
             )
             if primitive_table.is_empty():
                 raise ValueError(
@@ -245,7 +257,7 @@ def check_dataset_tasks(dataset: HafniaDataset):
                     + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
                 )

-            actual_classes = set(primitive_table[
+            actual_classes = set(primitive_table[PrimitiveField.CLASS_NAME].unique().to_list())
             if task.class_names is None:
                 raise ValueError(
                     msg_something_wrong
@@ -260,12 +272,12 @@ def check_dataset_tasks(dataset: HafniaDataset):
                 f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
             )
             # Check class_indices
-            mapped_indices = primitive_table[
+            mapped_indices = primitive_table[PrimitiveField.CLASS_NAME].map_elements(
                 lambda x: task.class_names.index(x), return_dtype=pl.Int64
             )
-            table_indices = primitive_table[
+            table_indices = primitive_table[PrimitiveField.CLASS_IDX]

             error_msg = msg_something_wrong + (
-                f"class indices in '{
+                f"class indices in '{PrimitiveField.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
             )
             assert mapped_indices.equals(table_indices), error_msg
```
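The renamed stats helpers above can be exercised directly; a minimal usage sketch, assuming an already-constructed `HafniaDataset` instance named `dataset` (dataset loading is not part of this diff):

```python
# Hypothetical usage sketch of the 0.4.1 stats helpers; `dataset` is assumed to be
# an existing HafniaDataset instance (how it is loaded is outside this diff).
from hafnia.dataset.operations.dataset_stats import (
    calculate_class_counts,
    calculate_split_counts,
    calculate_split_counts_extended,
    check_dataset,
)

split_counts = calculate_split_counts(dataset)             # e.g. {"train": 800, "val": 100, "test": 100}
class_counts = calculate_class_counts(dataset)             # list of {"Primitive", "Task Name", "Class Name", "Count"} dicts
per_split_rows = calculate_split_counts_extended(dataset)  # one row per split with per-primitive annotation counts
check_dataset(dataset, check_splits=False)                 # 0.4.1: split/sample-dataset checks can be skipped
```

Note that `check_dataset` now only requires the valid splits to be a subset of the splits present, and the split and sample-dataset checks can be disabled with `check_splits=False`.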
```diff
--- a/hafnia/dataset/operations/dataset_transformations.py
+++ b/hafnia/dataset/operations/dataset_transformations.py
@@ -31,6 +31,7 @@ that the signatures match.

 import json
 import re
+import shutil
 import textwrap
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -39,17 +40,23 @@ import cv2
 import more_itertools
 import numpy as np
 import polars as pl
-from PIL import Image
 from rich.progress import track

 from hafnia.dataset import dataset_helpers
-from hafnia.dataset.dataset_names import
+from hafnia.dataset.dataset_names import (
+    OPS_REMOVE_CLASS,
+    PrimitiveField,
+    SampleField,
+    StorageFormat,
+)
+from hafnia.dataset.operations.table_transformations import update_class_indices
 from hafnia.dataset.primitives import get_primitive_type_from_string
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
 from hafnia.utils import remove_duplicates_preserve_order

 if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
-    from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample, TaskInfo


 ### Image transformations ###
@@ -57,7 +64,7 @@ class AnonymizeByPixelation:
     def __init__(self, resize_factor: float = 0.10):
         self.resize_factor = resize_factor

-    def __call__(self, frame: np.ndarray) -> np.ndarray:
+    def __call__(self, frame: np.ndarray, sample: "Sample") -> np.ndarray:
         org_size = frame.shape[:2]
         frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
         frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
```
```diff
@@ -66,31 +73,100 @@ class AnonymizeByPixelation:

 def transform_images(
     dataset: "HafniaDataset",
-    transform: Callable[[np.ndarray], np.ndarray],
+    transform: Callable[[np.ndarray, "Sample"], np.ndarray],
     path_output: Path,
+    description: str = "Transform images",
 ) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import Sample
+
     new_paths = []
     path_image_folder = path_output / "data"
     path_image_folder.mkdir(parents=True, exist_ok=True)

-
-
-
-
-            raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
-
-        image = np.array(Image.open(org_path))
-        image_transformed = transform(image)
+    for sample_dict in track(dataset, description=description):
+        sample = Sample(**sample_dict)
+        image = sample.read_image()
+        image_transformed = transform(image, sample)
         new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)

         if not new_path.exists():
             raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
         new_paths.append(str(new_path))

-    table = dataset.samples.with_columns(pl.Series(new_paths).alias(
+    table = dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
     return dataset.update_samples(table)


+def convert_to_image_storage_format(
+    dataset: "HafniaDataset",
+    path_output_folder: Path,
+    reextract_frames: bool,
+    image_format: str = "png",
+    transform: Optional[Callable[[np.ndarray, "Sample"], np.ndarray]] = None,
+) -> "HafniaDataset":
+    """
+    Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
+    to an image-based dataset by extracting frames.
+    """
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+    path_images = path_output_folder / "data"
+    path_images.mkdir(parents=True, exist_ok=True)
+
+    # Only video format dataset samples are processed
+    video_based_samples = dataset.samples.filter(pl.col(SampleField.STORAGE_FORMAT) == StorageFormat.VIDEO)
+
+    if video_based_samples.is_empty():
+        user_logger.info("Dataset has no video-based samples. Returning dataset unchanged.")
+        return dataset
+
+    update_list = []
+    for (path_video,), video_samples in video_based_samples.group_by(SampleField.FILE_PATH):
+        assert Path(path_video).exists(), (
+            f"'{path_video}' not found. We expect the video to be downloaded to '{path_output_folder}'"
+        )
+        video = cv2.VideoCapture(str(path_video))
+
+        video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
+        for sample_dict in track(
+            video_samples.iter_rows(named=True),
+            total=video_samples.height,
+            description=f"Extracting frames from '{Path(path_video).name}'",
+        ):
+            frame_number = sample_dict[SampleField.COLLECTION_INDEX]
+            image_name = f"{Path(path_video).stem}_F{frame_number:06d}.{image_format}"
+            path_image = path_images / image_name
+
+            update_list.append(
+                {
+                    SampleField.SAMPLE_INDEX: sample_dict[SampleField.SAMPLE_INDEX],
+                    SampleField.COLLECTION_ID: sample_dict[SampleField.COLLECTION_ID],
+                    SampleField.COLLECTION_INDEX: frame_number,
+                    SampleField.FILE_PATH: path_image.as_posix(),
+                    SampleField.STORAGE_FORMAT: StorageFormat.IMAGE,
+                }
+            )
+            if reextract_frames:
+                shutil.rmtree(path_image, ignore_errors=True)
+            if path_image.exists():
+                continue
+
+            video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+            ret, frame_org = video.read()
+            if not ret:
+                raise RuntimeError(f"Could not read frame {frame_number} from video '{path_video}'")
+
+            if transform is not None:
+                frame_org = transform(frame_org, Sample(**sample_dict))
+
+            cv2.imwrite(str(path_image), frame_org)
+    df_updates = pl.DataFrame(update_list)
+    samples_as_images = dataset.samples.update(df_updates, on=[SampleField.COLLECTION_ID, SampleField.COLLECTION_INDEX])
+    hafnia_dataset = HafniaDataset(samples=samples_as_images, info=dataset.info)
+
+    return hafnia_dataset
+
+
 def get_task_info_from_task_name_and_primitive(
     tasks: List["TaskInfo"],
     task_name: Optional[str] = None,
```
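Both image helpers in `dataset_transformations.py` now pass the `Sample` to the transform callable, and frame extraction from video-backed datasets is new in 0.4.1. A hedged usage sketch, assuming an existing `HafniaDataset` named `dataset` (for `convert_to_image_storage_format`, its samples are expected to be video-based, as described in the docstring above); the output paths are illustrative:

```python
# Hypothetical usage sketch; `dataset` is an existing HafniaDataset (construction not
# shown in this diff). Output folders below are illustrative.
from pathlib import Path

from hafnia.dataset.operations.dataset_transformations import (
    AnonymizeByPixelation,
    convert_to_image_storage_format,
    transform_images,
)

# 0.4.1: the transform callable receives (image, sample) instead of just the image.
anonymized = transform_images(
    dataset,
    transform=AnonymizeByPixelation(resize_factor=0.10),
    path_output=Path("./anonymized"),
    description="Anonymizing images",
)

# New in 0.4.1: turn a video-based dataset into an image-based one by extracting frames.
as_images = convert_to_image_storage_format(
    dataset,
    path_output_folder=Path("./extracted"),
    reextract_frames=False,  # keep frames that were already written to disk
    image_format="png",
)
```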
```diff
@@ -223,25 +299,10 @@ def class_mapper(
         pl.col(task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.when(pl.field(
-                .then(pl.field(
-                .otherwise(pl.field(
-                .alias(
-            )
-        )
-        .alias(task.primitive.column_name())
-    )
-
-    # Update class indices too
-    name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}
-    samples_updated = samples_updated.with_columns(
-        pl.col(task.primitive.column_name())
-        .list.eval(
-            pl.element().struct.with_fields(
-                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
-                .otherwise(pl.field(FieldName.CLASS_IDX))
-                .alias(FieldName.CLASS_IDX)
+                pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(class_mapping, default="Missing"))
+                .otherwise(pl.field(PrimitiveField.CLASS_NAME))
+                .alias(PrimitiveField.CLASS_NAME)
             )
         )
         .alias(task.primitive.column_name())
@@ -250,7 +311,7 @@ def class_mapper(
     if OPS_REMOVE_CLASS in new_class_names:  # Remove class_names that are mapped to REMOVE_CLASS
         samples_updated = samples_updated.with_columns(
             pl.col(task.primitive.column_name())
-            .list.filter(pl.element().struct.field(
+            .list.filter(pl.element().struct.field(PrimitiveField.CLASS_NAME) != OPS_REMOVE_CLASS)
             .alias(task.primitive.column_name())
         )

@@ -259,6 +320,10 @@ def class_mapper(
     new_task = task.model_copy(deep=True)
     new_task.class_names = new_class_names
     dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+
+    # Update class indices to match new class names
+    samples_updated = update_class_indices(samples_updated, new_task)
+
     return HafniaDataset(info=dataset_info, samples=samples_updated)


```
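The remapping above is built on polars list-of-structs expressions (`list.eval`, `struct.with_fields`, `pl.when/then/otherwise`, `replace_strict`). A minimal, self-contained sketch of the same pattern with illustrative names (`objects`, `task_name`, `class_name`) rather than hafnia's actual schema; the selection at the end mirrors `select_samples_by_class_name` further below:

```python
# Standalone polars sketch (illustrative schema, not hafnia's): remap 'class_name'
# inside a list-of-structs column for one task, then keep rows that still contain
# a wanted class, mirroring class_mapper / select_samples_by_class_name.
import polars as pl

df = pl.DataFrame(
    {
        "sample": ["a", "b", "c"],
        "objects": [
            [{"task_name": "detection", "class_name": "car"}],
            [{"task_name": "detection", "class_name": "bus"}],
            [{"task_name": "scene", "class_name": "indoor"}],
        ],
    }
)

class_mapping = {"car": "vehicle", "bus": "vehicle"}
df = df.with_columns(
    pl.col("objects")
    .list.eval(
        pl.element().struct.with_fields(
            pl.when(pl.field("task_name") == "detection")
            .then(pl.field("class_name").replace_strict(class_mapping, default="Missing"))
            .otherwise(pl.field("class_name"))
            .alias("class_name")
        )
    )
    .alias("objects")
)

# Keep rows where any 'detection' annotation is one of the wanted classes.
wanted = ["vehicle"]
selected = df.filter(
    pl.col("objects")
    .list.eval(
        pl.element().struct.field("class_name").is_in(wanted)
        & (pl.element().struct.field("task_name") == "detection")
    )
    .list.any()
)
print(selected)  # rows "a" and "b": both remapped to 'vehicle' under the 'detection' task
```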
```diff
@@ -317,7 +382,7 @@ def rename_task(
         pl.col(old_task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.field(
+                pl.field(PrimitiveField.TASK_NAME).replace(old_task.name, new_task.name).alias(PrimitiveField.TASK_NAME)
             )
         )
         .alias(new_task.primitive.column_name())
@@ -343,8 +408,8 @@ def select_samples_by_class_name(
     samples = dataset.samples.filter(
         pl.col(task.primitive.column_name())
         .list.eval(
-            pl.element().struct.field(
-            & (pl.element().struct.field(
+            pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names)
+            & (pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
         )
         .list.any()
     )
```