hafnia 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_names.py +128 -15
- hafnia/dataset/dataset_upload_helper.py +30 -25
- hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
- hafnia/dataset/hafnia_dataset.py +246 -72
- hafnia/dataset/operations/dataset_stats.py +82 -70
- hafnia/dataset/operations/dataset_transformations.py +102 -37
- hafnia/dataset/operations/table_transformations.py +132 -15
- hafnia/dataset/primitives/bbox.py +3 -5
- hafnia/dataset/primitives/bitmask.py +2 -7
- hafnia/dataset/primitives/classification.py +3 -3
- hafnia/dataset/primitives/polygon.py +2 -4
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +2 -2
- hafnia/platform/datasets.py +3 -7
- hafnia/platform/download.py +1 -72
- hafnia/torch_helpers.py +12 -12
- hafnia/visualizations/image_visualizations.py +2 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +4 -4
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/RECORD +25 -24
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type

 import polars as pl
 from rich.progress import track
@@ -7,8 +7,8 @@ from rich.progress import track
 from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_JSONL,
     FILENAME_ANNOTATIONS_PARQUET,
-
-
+    PrimitiveField,
+    SampleField,
 )
 from hafnia.dataset.operations import table_transformations
 from hafnia.dataset.primitives import PRIMITIVE_TYPES
@@ -16,9 +16,15 @@ from hafnia.dataset.primitives.classification import Classification
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.log import user_logger

+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import TaskInfo
+

 def create_primitive_table(
-    samples_table: pl.DataFrame,
+    samples_table: pl.DataFrame,
+    PrimitiveType: Type[Primitive],
+    keep_sample_data: bool = False,
+    task_name: Optional[str] = None,
 ) -> Optional[pl.DataFrame]:
     """
     Returns a DataFrame with objects of the specified primitive type.
@@ -48,6 +54,9 @@ def create_primitive_table(
         objects_df = remove_no_object_frames.explode(column_name).unnest(column_name)
     else:
         objects_df = remove_no_object_frames.select(pl.col(column_name).explode().struct.unnest())
+
+    if task_name is not None:
+        objects_df = objects_df.filter(pl.col(PrimitiveField.TASK_NAME) == task_name)
     return objects_df


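For orientation, a minimal usage sketch of the new task_name argument on create_primitive_table. The helper's module path is inferred from the changed-files list above (table_transformations.py) rather than shown verbatim in this diff, and the dataset path is a placeholder.

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.dataset.operations.table_transformations import create_primitive_table  # inferred module path
from hafnia.dataset.primitives.bbox import Bbox

dataset = HafniaDataset.from_path(Path("data/my-dataset"), check_for_images=False)  # placeholder path
# Keep only Bbox objects that belong to the "object_detection" task.
bboxes = create_primitive_table(dataset.samples, PrimitiveType=Bbox, task_name="object_detection")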
@@ -55,11 +64,12 @@ def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFram
     has_same_schema = samples0.schema == samples1.schema
     if not has_same_schema:
         shared_columns = []
-        for column_name,
+        for column_name, s0_column_type in samples0.schema.items():
             if column_name not in samples1.schema:
                 continue
+            samples0, samples1 = correction_of_list_struct_primitives(samples0, samples1, column_name)

-            if
+            if samples0.schema[column_name] != samples1.schema[column_name]:
                 continue
             shared_columns.append(column_name)

@@ -79,16 +89,58 @@ def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFram
     samples0 = samples0.select(list(shared_columns))
     samples1 = samples1.select(list(shared_columns))
     merged_samples = pl.concat([samples0, samples1], how="vertical")
-    merged_samples = merged_samples
+    merged_samples = add_sample_index(merged_samples)
     return merged_samples


+def correction_of_list_struct_primitives(
+    samples0: pl.DataFrame,
+    samples1: pl.DataFrame,
+    column_name: str,
+) -> Tuple[pl.DataFrame, pl.DataFrame]:
+    """
+    Corrects primitive columns (bboxes, polygons etc of type 'list[struct]') by removing non-matching struct fields
+    between two datasets. This is useful when merging two datasets with the same primitive (e.g. Bbox), where
+    some (less important) field types in the struct differ between the two datasets.
+    This issue often occurs with the 'meta' field as different dataset formats may store different metadata information.
+    """
+    s0_column_type = samples0.schema[column_name]
+    s1_column_type = samples1.schema[column_name]
+    is_list_structs = s1_column_type == pl.List(pl.Struct) and s0_column_type == pl.List(pl.Struct)
+    is_non_matching_types = s1_column_type != s0_column_type
+    if is_list_structs and is_non_matching_types:  # Only perform correction for list[struct] types that do not match
+        s0_fields = set(s0_column_type.inner.fields)
+        s1_fields = set(s1_column_type.inner.fields)
+        similar_fields = s0_fields.intersection(s1_fields)
+        s0_dropped_fields = s0_fields - similar_fields
+        if len(s0_dropped_fields) > 0:
+            samples0 = samples0.with_columns(
+                pl.col(column_name)
+                .list.eval(pl.struct([pl.element().struct.field(k.name) for k in similar_fields]))
+                .alias(column_name)
+            )
+        s1_dropped_fields = s1_fields - similar_fields
+        if len(s1_dropped_fields) > 0:
+            samples1 = samples1.with_columns(
+                pl.col(column_name)
+                .list.eval(pl.struct([pl.element().struct.field(k.name) for k in similar_fields]))
+                .alias(column_name)
+            )
+        user_logger.warning(
+            f"Primitive column '{column_name}' has none-matching fields in the two datasets. "
+            f"Dropping fields in samples0: {[f.name for f in s0_dropped_fields]}. "
+            f"Dropping fields in samples1: {[f.name for f in s1_dropped_fields]}."
+        )
+
+    return samples0, samples1
+
+
 def filter_table_for_class_names(
     samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
 ) -> Optional[pl.DataFrame]:
     table_with_selected_class_names = samples_table.filter(
         pl.col(PrimitiveType.column_name())
-        .list.eval(pl.element().struct.field(
+        .list.eval(pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names))
         .list.any()
     )

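To illustrate the correction step above outside the package, here is a small polars-only sketch of the same technique: keep only the struct fields shared by both frames so the two list[struct] columns concatenate cleanly. The toy data and column name are invented for the example.

import polars as pl

df0 = pl.DataFrame({"bboxes": [[{"class_name": "car", "meta": "from-coco"}]]})
df1 = pl.DataFrame({"bboxes": [[{"class_name": "bus", "score": 0.9}]]})

fields0 = {f.name for f in df0.schema["bboxes"].inner.fields}
fields1 = {f.name for f in df1.schema["bboxes"].inner.fields}
shared = fields0 & fields1  # here only {"class_name"} survives

# Rebuild each struct with only the shared fields, then the schemas match.
keep_shared = pl.col("bboxes").list.eval(pl.struct([pl.element().struct.field(n) for n in shared]))
df0 = df0.with_columns(keep_shared.alias("bboxes"))
df1 = df1.with_columns(keep_shared.alias("bboxes"))

print(pl.concat([df0, df1], how="vertical"))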
@@ -100,20 +152,20 @@ def split_primitive_columns_by_task_name(
     coordinate_types: Optional[List[Type[Primitive]]] = None,
 ) -> pl.DataFrame:
     """
-    Convert Primitive columns such as "
-    For example, if the "
+    Convert Primitive columns such as "bboxes" (Bbox) into a column for each task name.
+    For example, if the "bboxes" column (containing Bbox objects) has tasks "task1" and "task2".


     This:
     ─┬────────────┬─
-     ┆
+     ┆ bboxes     ┆
      ┆ ---        ┆
      ┆ list[struc ┆
      ┆ t[11]]     ┆
     ═╪════════════╪═
     becomes this:
     ─┬────────────┬────────────┬─
-     ┆
+     ┆ bboxes.    ┆ bboxes.    ┆
      ┆ task1      ┆ task2      ┆
      ┆ ---        ┆ ---        ┆
      ┆ list[struc ┆ list[struc ┆
@@ -131,11 +183,11 @@ def split_primitive_columns_by_task_name(
         if samples_table[col_name].dtype != pl.List(pl.Struct):
             continue

-        task_names = samples_table[col_name].explode().struct.field(
+        task_names = samples_table[col_name].explode().struct.field(PrimitiveField.TASK_NAME).unique().to_list()
         samples_table = samples_table.with_columns(
             [
                 pl.col(col_name)
-                .list.filter(pl.element().struct.field(
+                .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME).eq(task_name))
                 .alias(f"{col_name}.{task_name}")
                 for task_name in task_names
             ]
@@ -162,7 +214,7 @@ def read_samples_from_path(path: Path) -> pl.DataFrame:

 def check_image_paths(table: pl.DataFrame) -> bool:
     missing_files = []
-    org_paths = table[
+    org_paths = table[SampleField.FILE_PATH].to_list()
     for org_path in track(org_paths, description="Check image paths"):
         org_path = Path(org_path)
         if not org_path.exists():
@@ -219,3 +271,68 @@ def unnest_classification_tasks(table: pl.DataFrame, strict: bool = True) -> pl.

     table_out = table_out.with_columns([pl.col(c).list.first() for c in classification_columns])
     return table_out
+
+
+def update_class_indices(samples: pl.DataFrame, task: "TaskInfo") -> pl.DataFrame:
+    if task.class_names is None or len(task.class_names) == 0:
+        raise ValueError(f"Task '{task.name}' does not have defined class names to update class indices.")
+
+    objs = (
+        samples[task.primitive.column_name()]
+        .explode()
+        .struct.unnest()
+        .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
+    )
+    expected_class_names = set(objs[PrimitiveField.CLASS_NAME].unique())
+    missing_class_names = expected_class_names - set(task.class_names)
+    if len(missing_class_names) > 0:
+        raise ValueError(
+            f"Task '{task.name}' is missing class names: {missing_class_names}. Cannot update class indices."
+        )
+
+    name_2_idx_mapping = {name: idx for idx, name in enumerate(task.class_names)}
+
+    samples_updated = samples.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.eval(
+            pl.element().struct.with_fields(
+                pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(name_2_idx_mapping, default=-1))
+                .otherwise(pl.field(PrimitiveField.CLASS_IDX))
+                .alias(PrimitiveField.CLASS_IDX)
+            )
+        )
+        .alias(task.primitive.column_name())
+    )
+
+    return samples_updated
+
+
+def add_sample_index(samples: pl.DataFrame) -> pl.DataFrame:
+    """
+    Adds a sample index column to the samples DataFrame.
+
+    Note: Unlike the built-in 'polars.DataFrame.with_row_count', this function
+    always guarantees 'pl.UInt64' type for the index column.
+    """
+    if SampleField.SAMPLE_INDEX in samples.columns:
+        samples = samples.drop(SampleField.SAMPLE_INDEX)
+    samples = samples.select(
+        pl.int_range(0, pl.count(), dtype=pl.UInt64).alias(SampleField.SAMPLE_INDEX),
+        pl.all(),
+    )
+    return samples
+
+
+def add_dataset_name_if_missing(table: pl.DataFrame, dataset_name: str) -> pl.DataFrame:
+    if SampleField.DATASET_NAME not in table.columns:
+        table = table.with_columns(pl.lit(dataset_name).alias(SampleField.DATASET_NAME))
+    else:
+        table = table.with_columns(
+            pl.when(pl.col(SampleField.DATASET_NAME).is_null())
+            .then(pl.lit(dataset_name))
+            .otherwise(pl.col(SampleField.DATASET_NAME))
+            .alias(SampleField.DATASET_NAME)
+        )
+
+    return table
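A quick polars-only illustration (toy data, literal column name) of the indexing pattern used by add_sample_index above; this snippet calls pl.len() where the package code uses the older pl.count() alias.

import polars as pl

table = pl.DataFrame({"file_name": ["a.jpg", "b.jpg", "c.jpg"]})
table = table.select(
    pl.int_range(0, pl.len(), dtype=pl.UInt64).alias("sample_index"),  # fresh UInt64 index column, prepended
    pl.all(),
)
print(table)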
hafnia/dataset/primitives/bbox.py
CHANGED
@@ -33,9 +33,7 @@ class Bbox(Primitive):
     class_name: Optional[str] = Field(default=None, description="Class name, e.g. 'car'")
     class_idx: Optional[int] = Field(default=None, description="Class index, e.g. 0 for 'car' if it is the first class")
     object_id: Optional[str] = Field(default=None, description="Unique identifier for the object, e.g. '12345123'")
-    confidence:
-        default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-    )
+    confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
     ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")

     task_name: str = Field(
@@ -45,11 +43,11 @@ class Bbox(Primitive):

     @staticmethod
     def default_task_name() -> str:
-        return "
+        return "object_detection"

     @staticmethod
     def column_name() -> str:
-        return "
+        return "bboxes"

     def calculate_area(self) -> float:
         return self.height * self.width
hafnia/dataset/primitives/bitmask.py
CHANGED
@@ -7,7 +7,6 @@ import numpy as np
 import pycocotools.mask as coco_mask
 from pydantic import Field

-from hafnia.dataset.dataset_names import FieldName
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.dataset.primitives.utils import (
     anonymize_by_resizing,
@@ -16,8 +15,6 @@ from hafnia.dataset.primitives.utils import (
     text_org_from_left_bottom_to_centered,
 )

-FieldName
-

 class Bitmask(Primitive):
     # Names should match names in FieldName
@@ -34,9 +31,7 @@ class Bitmask(Primitive):
     class_name: Optional[str] = Field(default=None, description="Class name of the object represented by the bitmask")
     class_idx: Optional[int] = Field(default=None, description="Class index of the object represented by the bitmask")
     object_id: Optional[str] = Field(default=None, description="Object ID of the instance represented by the bitmask")
-    confidence:
-        default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-    )
+    confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
     ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")

     task_name: str = Field(
@@ -46,7 +41,7 @@ class Bitmask(Primitive):

     @staticmethod
     def default_task_name() -> str:
-        return "
+        return "mask_detection"

     @staticmethod
     def column_name() -> str:
hafnia/dataset/primitives/classification.py
CHANGED
@@ -12,8 +12,8 @@ class Classification(Primitive):
     class_name: Optional[str] = Field(default=None, description="Class name, e.g. 'car'")
     class_idx: Optional[int] = Field(default=None, description="Class index, e.g. 0 for 'car' if it is the first class")
     object_id: Optional[str] = Field(default=None, description="Unique identifier for the object, e.g. '12345123'")
-    confidence:
-        default=
+    confidence: float = Field(
+        default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification"
     )
     ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")

@@ -27,7 +27,7 @@ class Classification(Primitive):

     @staticmethod
     def default_task_name() -> str:
-        return "
+        return "image_classification"

     @staticmethod
     def column_name() -> str:
hafnia/dataset/primitives/polygon.py
CHANGED
@@ -16,9 +16,7 @@ class Polygon(Primitive):
     class_name: Optional[str] = Field(default=None, description="Class name of the polygon")
     class_idx: Optional[int] = Field(default=None, description="Class index of the polygon")
     object_id: Optional[str] = Field(default=None, description="Object ID of the polygon")
-    confidence:
-        default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-    )
+    confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
     ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")

     task_name: str = Field(
@@ -40,7 +38,7 @@ class Polygon(Primitive):

     @staticmethod
     def default_task_name() -> str:
-        return "
+        return "polygon_detection"

     @staticmethod
     def column_name() -> str:
hafnia/dataset/primitives/segmentation.py
CHANGED
@@ -24,11 +24,11 @@ class Segmentation(Primitive):

     @staticmethod
     def default_task_name() -> str:
-        return "
+        return "semantic_segmentation"

     @staticmethod
     def column_name() -> str:
-        return "
+        return "segmentations"

     def calculate_area(self) -> float:
         raise NotImplementedError()
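The 0.4.1 defaults for the primitive classes can be summarised with a small check; the import paths are inferred from the changed-files list above and have not been verified against the wheel.

from hafnia.dataset.primitives.bbox import Bbox
from hafnia.dataset.primitives.classification import Classification
from hafnia.dataset.primitives.segmentation import Segmentation

# Default task and column names as they appear in the 0.4.1 diff.
assert Bbox.default_task_name() == "object_detection"
assert Bbox.column_name() == "bboxes"
assert Classification.default_task_name() == "image_classification"
assert Segmentation.default_task_name() == "semantic_segmentation"
assert Segmentation.column_name() == "segmentations"
# Note: the 'confidence' field on all primitives now defaults to 1.0 instead of None.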
hafnia/platform/datasets.py
CHANGED
@@ -13,7 +13,7 @@ from rich.progress import track

 from cli.config import Config
 from hafnia import http, utils
-from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED
+from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED
 from hafnia.dataset.dataset_recipe.dataset_recipe import (
     DatasetRecipe,
     get_dataset_path_from_recipe,
@@ -120,15 +120,11 @@ def download_dataset_from_access_endpoint(
         return
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
     try:
-
-            src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
-            dst_paths=dataset.samples[ColumnName.FILE_PATH].to_list(),
-            append_envs=envs,
-            description="Downloading images",
-        )
+        dataset = dataset.download_files_aws(path_dataset, aws_credentials=resource_credentials, force_redownload=True)
     except ValueError as e:
         user_logger.error(f"Failed to download images: {e}")
         return
+    dataset.write_annotations(path_folder=path_dataset)  # Overwrite annotations as files have been re-downloaded


 def fast_copy_files_s3(
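A sketch of the new download flow shown above, with placeholder endpoint, API key, and path. The calls themselves (get_resource_credentials, HafniaDataset.from_path, download_files_aws, write_annotations) all appear in this diff; the surrounding wiring is illustrative only.

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.platform.download import get_resource_credentials

path_dataset = Path("data/my-dataset")  # placeholder
credentials = get_resource_credentials("https://example.invalid/access", "my-api-key")  # placeholders

dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
dataset = dataset.download_files_aws(path_dataset, aws_credentials=credentials, force_redownload=True)
dataset.write_annotations(path_folder=path_dataset)  # re-write annotations after files are re-downloaded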
hafnia/platform/download.py
CHANGED
@@ -3,83 +3,12 @@ from typing import Dict, Optional

 import boto3
 from botocore.exceptions import ClientError
-from pydantic import BaseModel, field_validator
 from rich.progress import Progress

+from hafnia.dataset.dataset_names import ResourceCredentials
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger

-ARN_PREFIX = "arn:aws:s3:::"
-
-
-class ResourceCredentials(BaseModel):
-    access_key: str
-    secret_key: str
-    session_token: str
-    s3_arn: str
-    region: str
-
-    @staticmethod
-    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
-        """
-        The endpoint returns a payload with a key called 's3_path', but it
-        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
-        """
-        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
-            payload["s3_arn"] = payload.pop("s3_path")
-
-        if "region" not in payload:
-            payload["region"] = "eu-west-1"
-        return ResourceCredentials(**payload)
-
-    @field_validator("s3_arn")
-    @classmethod
-    def validate_s3_arn(cls, value: str) -> str:
-        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
-        if not value.startswith("arn:aws:s3:::"):
-            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
-        return value
-
-    def s3_path(self) -> str:
-        """
-        Extracts the S3 path from the ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
-        """
-        return self.s3_arn[len(ARN_PREFIX) :]
-
-    def s3_uri(self) -> str:
-        """
-        Converts the S3 ARN to a URI format.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
-        """
-        return f"s3://{self.s3_path()}"
-
-    def bucket_name(self) -> str:
-        """
-        Extracts the bucket name from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
-        """
-        return self.s3_path().split("/")[0]
-
-    def object_key(self) -> str:
-        """
-        Extracts the object key from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
-        """
-        return "/".join(self.s3_path().split("/")[1:])
-
-    def aws_credentials(self) -> Dict[str, str]:
-        """
-        Returns the AWS credentials as a dictionary.
-        """
-        environment_vars = {
-            "AWS_ACCESS_KEY_ID": self.access_key,
-            "AWS_SECRET_ACCESS_KEY": self.secret_key,
-            "AWS_SESSION_TOKEN": self.session_token,
-            "AWS_REGION": self.region,
-        }
-        return environment_vars
-

 def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
     """
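ResourceCredentials now lives in hafnia.dataset.dataset_names. A sketch of the new import, assuming the helper methods shown in the removed block above were moved over unchanged (credential values are placeholders):

from hafnia.dataset.dataset_names import ResourceCredentials

creds = ResourceCredentials(
    access_key="AKIA...",  # placeholder
    secret_key="secret",   # placeholder
    session_token="token", # placeholder
    s3_arn="arn:aws:s3:::my-bucket/my-prefix",
    region="eu-west-1",
)
print(creds.s3_uri())       # s3://my-bucket/my-prefix
print(creds.bucket_name())  # my-bucket
print(creds.object_key())   # my-prefix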
hafnia/torch_helpers.py
CHANGED
@@ -9,7 +9,7 @@ from torchvision import tv_tensors
 from torchvision import utils as tv_utils
 from torchvision.transforms import v2

-from hafnia.dataset.dataset_names import
+from hafnia.dataset.dataset_names import PrimitiveField
 from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
 from hafnia.dataset.primitives import (
     PRIMITIVE_COLUMN_NAMES,
@@ -68,8 +68,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
         for task_name, classifications in class_tasks.items():
             assert len(classifications) == 1, "Expected exactly one classification task per sample"
             target_flat[f"{Classification.column_name()}.{task_name}"] = {
-
-
+                PrimitiveField.CLASS_IDX: classifications[0].class_idx,
+                PrimitiveField.CLASS_NAME: classifications[0].class_name,
             }

         bbox_tasks: Dict[str, List[Bbox]] = get_primitives_per_task_name_for_primitive(sample, Bbox)
@@ -77,8 +77,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
             bboxes_list = [bbox.to_coco(image_height=h, image_width=w) for bbox in bboxes]
             bboxes_tensor = torch.as_tensor(bboxes_list).reshape(-1, 4)
             target_flat[f"{Bbox.column_name()}.{task_name}"] = {
-
-
+                PrimitiveField.CLASS_IDX: [bbox.class_idx for bbox in bboxes],
+                PrimitiveField.CLASS_NAME: [bbox.class_name for bbox in bboxes],
                 "bbox": tv_tensors.BoundingBoxes(bboxes_tensor, format="XYWH", canvas_size=(h, w)),
             }

@@ -86,8 +86,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
         for task_name, bitmasks in bitmask_tasks.items():
             bitmasks_np = np.array([bitmask.to_mask(img_height=h, img_width=w) for bitmask in bitmasks])
             target_flat[f"{Bitmask.column_name()}.{task_name}"] = {
-
-
+                PrimitiveField.CLASS_IDX: [bitmask.class_idx for bitmask in bitmasks],
+                PrimitiveField.CLASS_NAME: [bitmask.class_name for bitmask in bitmasks],
                 "mask": tv_tensors.Mask(bitmasks_np),
             }

@@ -161,7 +161,7 @@ def draw_image_and_targets(
     if Bitmask.column_name() in targets:
         primitive_annotations = targets[Bitmask.column_name()]
         for task_name, task_annotations in primitive_annotations.items():
-            colors = [class_color_by_name(class_name) for class_name in task_annotations[
+            colors = [class_color_by_name(class_name) for class_name in task_annotations[PrimitiveField.CLASS_NAME]]
             visualize_image = tv_utils.draw_segmentation_masks(
                 image=visualize_image,
                 masks=task_annotations["mask"],
@@ -172,11 +172,11 @@ def draw_image_and_targets(
         primitive_annotations = targets[Bbox.column_name()]
         for task_name, task_annotations in primitive_annotations.items():
             bboxes = torchvision.ops.box_convert(task_annotations["bbox"], in_fmt="xywh", out_fmt="xyxy")
-            colors = [class_color_by_name(class_name) for class_name in task_annotations[
+            colors = [class_color_by_name(class_name) for class_name in task_annotations[PrimitiveField.CLASS_NAME]]
             visualize_image = tv_utils.draw_bounding_boxes(
                 image=visualize_image,
                 boxes=bboxes,
-                labels=task_annotations[
+                labels=task_annotations[PrimitiveField.CLASS_NAME],
                 width=2,
                 colors=colors,
             )
@@ -187,9 +187,9 @@ def draw_image_and_targets(
         text_labels = []
         for task_name, task_annotations in primitive_annotations.items():
             if task_name == Classification.default_task_name():
-                text_label = task_annotations[
+                text_label = task_annotations[PrimitiveField.CLASS_NAME]
             else:
-                text_label = f"{task_name}: {task_annotations[
+                text_label = f"{task_name}: {task_annotations[PrimitiveField.CLASS_NAME]}"
             text_labels.append(text_label)
         visualize_image = draw_image_classification(visualize_image, text_labels)
     return visualize_image
hafnia/visualizations/image_visualizations.py
CHANGED
@@ -193,6 +193,8 @@ def save_dataset_sample_set_visualizations(
         image = draw_annotations(image, annotations, draw_settings=draw_settings)

         pil_image = Image.fromarray(image)
+        if sample.file_path is None:
+            raise ValueError("Sample has no file_path defined.")
         path_image = path_output_folder / Path(sample.file_path).name
         pil_image.save(path_image)
         paths.append(path_image)
{hafnia-0.4.0.dist-info → hafnia-0.4.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hafnia
-Version: 0.4.0
+Version: 0.4.1
 Summary: Python SDK for communication with Hafnia platform.
 Author-email: Milestone Systems <hafniaplatform@milestone.dk>
 License-File: LICENSE
@@ -158,7 +158,7 @@ and `dataset.samples` with annotations as a polars DataFrame
 print(dataset.samples.head(2))
 shape: (2, 14)
 ┌──────────────┬─────────────────────────────────┬────────┬───────┬───┬─────────────────────────────────┬──────────┬──────────┬─────────────────────────────────┐
-│ sample_index ┆ file_name ┆ height ┆ width ┆ … ┆
+│ sample_index ┆ file_name ┆ height ┆ width ┆ … ┆ bboxes ┆ bitmasks ┆ polygons ┆ meta │
 │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
 │ u32 ┆ str ┆ i64 ┆ i64 ┆ ┆ list[struct[11]] ┆ null ┆ null ┆ struct[5] │
 ╞══════════════╪═════════════════════════════════╪════════╪═══════╪═══╪═════════════════════════════════╪══════════╪══════════╪═════════════════════════════════╡
@@ -218,7 +218,7 @@ sample_dict = dataset[0]

 for sample_dict in dataset:
     sample = Sample(**sample_dict)
-    print(sample.sample_id, sample.
+    print(sample.sample_id, sample.bboxes)
     break
 ```
 Not that it is possible to create a `Sample` object from the sample dictionary.
@@ -421,7 +421,7 @@ pil_image.save("visualized_labels.png")

 # Create DataLoaders - using TorchVisionCollateFn
 collate_fn = torch_helpers.TorchVisionCollateFn(
-    skip_stacking=["
+    skip_stacking=["bboxes.bbox", "bboxes.class_idx"]
 )
 train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
 ```