hafnia 0.1.26__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +2 -2
- cli/dataset_cmds.py +60 -0
- cli/runc_cmds.py +1 -1
- hafnia/data/__init__.py +2 -2
- hafnia/data/factory.py +9 -56
- hafnia/dataset/dataset_helpers.py +91 -0
- hafnia/dataset/dataset_names.py +71 -0
- hafnia/dataset/dataset_transformation.py +187 -0
- hafnia/dataset/dataset_upload_helper.py +468 -0
- hafnia/dataset/hafnia_dataset.py +453 -0
- hafnia/dataset/primitives/__init__.py +16 -0
- hafnia/dataset/primitives/bbox.py +137 -0
- hafnia/dataset/primitives/bitmask.py +182 -0
- hafnia/dataset/primitives/classification.py +56 -0
- hafnia/dataset/primitives/point.py +25 -0
- hafnia/dataset/primitives/polygon.py +100 -0
- hafnia/dataset/primitives/primitive.py +44 -0
- hafnia/dataset/primitives/segmentation.py +51 -0
- hafnia/dataset/primitives/utils.py +51 -0
- hafnia/dataset/table_transformations.py +183 -0
- hafnia/experiment/hafnia_logger.py +2 -2
- hafnia/helper_testing.py +63 -0
- hafnia/http.py +5 -3
- hafnia/platform/__init__.py +2 -2
- hafnia/platform/builder.py +25 -19
- hafnia/platform/datasets.py +184 -0
- hafnia/platform/download.py +85 -23
- hafnia/torch_helpers.py +180 -95
- hafnia/utils.py +1 -1
- hafnia/visualizations/colors.py +267 -0
- hafnia/visualizations/image_visualizations.py +202 -0
- {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
- hafnia-0.2.0.dist-info/RECORD +46 -0
- cli/data_cmds.py +0 -53
- hafnia-0.1.26.dist-info/RECORD +0 -27
- {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
- {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import shutil
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
8
|
+
|
|
9
|
+
import more_itertools
|
|
10
|
+
import numpy as np
|
|
11
|
+
import polars as pl
|
|
12
|
+
import rich
|
|
13
|
+
from PIL import Image
|
|
14
|
+
from pydantic import BaseModel, field_serializer, field_validator
|
|
15
|
+
from rich import print as rprint
|
|
16
|
+
from rich.table import Table
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
from hafnia.dataset import dataset_helpers, dataset_transformation
|
|
20
|
+
from hafnia.dataset.dataset_names import (
|
|
21
|
+
DATASET_FILENAMES,
|
|
22
|
+
FILENAME_ANNOTATIONS_JSONL,
|
|
23
|
+
FILENAME_ANNOTATIONS_PARQUET,
|
|
24
|
+
FILENAME_DATASET_INFO,
|
|
25
|
+
ColumnName,
|
|
26
|
+
FieldName,
|
|
27
|
+
SplitName,
|
|
28
|
+
)
|
|
29
|
+
from hafnia.dataset.primitives import (
|
|
30
|
+
PRIMITIVE_NAME_TO_TYPE,
|
|
31
|
+
PRIMITIVE_TYPES,
|
|
32
|
+
)
|
|
33
|
+
from hafnia.dataset.primitives.bbox import Bbox
|
|
34
|
+
from hafnia.dataset.primitives.bitmask import Bitmask
|
|
35
|
+
from hafnia.dataset.primitives.classification import Classification
|
|
36
|
+
from hafnia.dataset.primitives.polygon import Polygon
|
|
37
|
+
from hafnia.dataset.primitives.primitive import Primitive
|
|
38
|
+
from hafnia.dataset.table_transformations import (
|
|
39
|
+
check_image_paths,
|
|
40
|
+
create_primitive_table,
|
|
41
|
+
read_table_from_path,
|
|
42
|
+
)
|
|
43
|
+
from hafnia.log import user_logger
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TaskInfo(BaseModel):
    """
    Describes one annotation task of a dataset: which primitive it uses (Bbox,
    Bitmask, ...), which classes it contains and the task name used to look it
    up in the sample table.
    """

    primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
    class_names: Optional[List[str]]  # Class names for the tasks. To get consistent class indices specify class_names.
    name: Optional[str] = (
        None  # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
    )

    def model_post_init(self, __context: Any) -> None:
        # Fill in the primitive's default task name (e.g. Bbox -> "bboxes")
        # when no explicit name was given.
        if self.name is None:
            self.name = self.primitive.default_task_name()

    # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
    # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
    # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
    # correctly validate and serialize sub-classes (Bbox, Classification, ...).
    @field_validator("primitive", mode="plain")
    @classmethod
    def ensure_primitive(cls, primitive: Any) -> Any:
        """Accept either a primitive class name (str) or a Primitive subclass; return the class."""
        if isinstance(primitive, str):
            if primitive not in PRIMITIVE_NAME_TO_TYPE:
                raise ValueError(
                    f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
                )

            return PRIMITIVE_NAME_TO_TYPE[primitive]

        if issubclass(primitive, Primitive):
            return primitive

        raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")

    @field_serializer("primitive")
    @classmethod
    def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
        """Serialize the primitive class as its class name (inverse of 'ensure_primitive')."""
        if not issubclass(primitive, Primitive):
            raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
        return primitive.__name__
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class DatasetInfo(BaseModel):
    """Top-level metadata for a Hafnia dataset: name, version and its annotation tasks."""

    dataset_name: str
    version: str
    tasks: list[TaskInfo]
    distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
    meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.

    def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
        """Serialize this info object to JSON and write it to *path*."""
        path.write_text(self.model_dump_json(indent=indent))

    @staticmethod
    def from_json_file(path: Path) -> "DatasetInfo":
        """Read and validate a 'DatasetInfo' from a JSON file at *path*."""
        return DatasetInfo.model_validate_json(path.read_text())
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Sample(BaseModel):
    """
    One dataset sample (an image plus its annotations).

    Paths in 'file_name' are absolute after 'HafniaDataset.read_from_path' has
    resolved them; annotation fields mirror the primitive column names.
    """

    file_name: str
    height: int
    width: int
    split: str  # Split name, e.g., "train", "val", "test"
    is_sample: bool  # Indicates if this is a sample (True) or a metadata entry (False)
    collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
    collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
    remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
    sample_index: Optional[int] = None  # Don't manually set this, it is used for indexing samples in the dataset.
    classifications: Optional[List[Classification]] = None  # Optional classification primitive
    objects: Optional[List[Bbox]] = None  # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
    bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
    polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable

    meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.

    def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
        """
        Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.

        Optionally restrict the result to the given *primitive_types*.
        """
        primitive_types = primitive_types or PRIMITIVE_TYPES
        annotations_primitives = [
            getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
        ]
        annotations = more_itertools.flatten(
            [primitives for primitives in annotations_primitives if primitives is not None]
        )

        return list(annotations)

    def read_image_pillow(self) -> Image.Image:
        """
        Reads the image from the file path and returns it as a PIL Image.
        Raises FileNotFoundError if the image file does not exist.
        """
        path_image = Path(self.file_name)
        if not path_image.exists():
            raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

        image = Image.open(str(path_image))
        return image

    def read_image(self) -> np.ndarray:
        """Reads the image from the file path and returns it as a numpy array."""
        image_pil = self.read_image_pillow()
        image = np.array(image_pil)
        return image

    def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Draw this sample's annotations onto *image* (or onto the sample's own
        image when *image* is None) and return the visualization.
        """
        from hafnia.visualizations import image_visualizations

        # Bug fix: 'image or self.read_image()' raises "truth value of an array
        # is ambiguous" for any multi-element ndarray - compare to None instead.
        if image is None:
            image = self.read_image()
        annotations = self.get_annotations()
        annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
        return annotations_visualized
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@dataclass
class HafniaDataset:
    """
    In-memory Hafnia dataset: a 'DatasetInfo' describing the dataset and its
    tasks plus a polars DataFrame ('samples') with one row per sample.
    """

    info: DatasetInfo
    samples: pl.DataFrame

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """Return sample *item* as a dict (column name -> value)."""
        return self.samples.row(index=item, named=True)

    def __len__(self) -> int:
        return len(self.samples)

    def __iter__(self):
        """Iterate over samples as dicts."""
        for row in self.samples.iter_rows(named=True):
            yield row

    # Dataset transformations
    apply_image_transform = dataset_transformation.transform_images
    sample = dataset_transformation.sample
    shuffle = dataset_transformation.shuffle_dataset
    split_by_ratios = dataset_transformation.splits_by_ratios
    divide_split_into_multiple_splits = dataset_transformation.divide_split_into_multiple_splits
    sample_set_by_size = dataset_transformation.define_sample_set_by_size

    @staticmethod
    def from_samples_list(samples_list: List, info: DatasetInfo) -> "HafniaDataset":
        """
        Build a dataset from a list of 'Sample' objects or sample dicts.

        Raises ValueError for an empty list and TypeError for unsupported item types.
        """
        if not samples_list:
            # Previously an empty list raised a bare IndexError below.
            raise ValueError("'samples_list' is empty - cannot create a dataset without samples.")
        sample = samples_list[0]
        if isinstance(sample, Sample):
            json_samples = [sample.model_dump(mode="json") for sample in samples_list]
        elif isinstance(sample, dict):
            json_samples = samples_list
        else:
            raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")

        table = pl.from_records(json_samples).drop(ColumnName.SAMPLE_INDEX)
        table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)  # Add sample index column

        return HafniaDataset(info=info, samples=table)

    def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
        """Return a mapping of split name -> dataset restricted to that split."""
        if ColumnName.SPLIT not in self.samples.columns:
            raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")

        splits = {}
        for split_name in SplitName.valid_splits():
            splits[split_name] = self.create_split_dataset(split_name)

        return splits

    def create_sample_dataset(self) -> "HafniaDataset":
        """Return a dataset containing only rows flagged as samples ('is_sample' == True)."""
        if ColumnName.IS_SAMPLE not in self.samples.columns:
            raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
        table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
        return self.update_table(table)

    def create_split_dataset(self, split_name: Union[str, List[str]]) -> "HafniaDataset":
        """Return a dataset containing only the given split(s)."""
        if isinstance(split_name, str):
            split_names = [split_name]
        elif isinstance(split_name, list):
            split_names = split_name
        else:
            # Bug fix: an unsupported type previously fell through and raised a
            # confusing NameError on 'split_names'; fail early and clearly instead.
            raise TypeError(f"'split_name' must be a str or a list of str, got {type(split_name)} instead.")

        for name in split_names:
            if name not in SplitName.valid_splits():
                raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")

        filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
        return self.update_table(filtered_dataset)

    def get_task_by_name(self, task_name: str) -> TaskInfo:
        """Return the TaskInfo whose name is *task_name*; raise ValueError when absent."""
        for task in self.info.tasks:
            if task.name == task_name:
                return task
        raise ValueError(f"Task with name {task_name} not found in dataset info.")

    def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
        """Return a new dataset with the same info (copied) and a new sample table."""
        return HafniaDataset(info=self.info.model_copy(), samples=table)

    @staticmethod
    def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
        """
        Checks if the dataset path exists and contains the required files.
        Returns True if the dataset is valid, otherwise raises an error or returns False.
        """
        if not path_dataset.exists():
            if raise_error:
                raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
            return False

        required_files = [
            FILENAME_DATASET_INFO,
            FILENAME_ANNOTATIONS_JSONL,
            FILENAME_ANNOTATIONS_PARQUET,
        ]
        for filename in required_files:
            if not (path_dataset / filename).exists():
                if raise_error:
                    # Bug fix: the message previously read "Required file (unknown)"
                    # and never named the missing file.
                    raise FileNotFoundError(f"Required file '{filename}' not found in {path_dataset}.")
                return False

        return True

    @staticmethod
    def read_from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
        """
        Load a dataset from *path_folder*, resolving relative image paths to
        absolute ones. Optionally verify that all image files exist.
        """
        HafniaDataset.check_dataset_path(path_folder, raise_error=True)

        dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
        table = read_table_from_path(path_folder)

        # Convert from relative paths to absolute paths
        table = table.with_columns(
            pl.concat_str([pl.lit(str(path_folder.absolute()) + os.sep), pl.col("file_name")]).alias("file_name")
        )
        if check_for_images:
            check_image_paths(table)
        return HafniaDataset(samples=table, info=dataset_info)

    def write(self, path_folder: Path, name_by_hash: bool = True, add_version: bool = False) -> None:
        """
        Write the dataset (images + annotations + info) to *path_folder*.

        Images are copied into '<path_folder>/data' and are renamed to their
        content hash when *name_by_hash* is True. With *add_version* a copy of
        the annotation/info files is stored under 'versions/<version>'.
        """
        user_logger.info(f"Writing dataset to {path_folder}...")
        if not path_folder.exists():
            path_folder.mkdir(parents=True)
        path_folder_images = path_folder / "data"
        path_folder_images.mkdir(parents=True, exist_ok=True)

        new_relative_paths = []
        for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
            org_path = Path(org_path)
            if not org_path.exists():
                raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
            if name_by_hash:
                filename = dataset_helpers.filename_as_hash_from_path(org_path)
            else:
                filename = Path(org_path).name
            new_path = path_folder_images / filename
            if not new_path.exists():
                shutil.copy2(org_path, new_path)

            if not new_path.exists():
                raise FileNotFoundError(f"File {new_path} does not exist in the dataset.")
            new_relative_paths.append(str(new_path.relative_to(path_folder)))

        table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
        table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
        table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
        self.info.write_json(path_folder / FILENAME_DATASET_INFO)

        if add_version:
            path_version = path_folder / "versions" / f"{self.info.version}"
            path_version.mkdir(parents=True, exist_ok=True)
            for filename in DATASET_FILENAMES:
                shutil.copy2(path_folder / filename, path_version / filename)

    def __eq__(self, value) -> bool:
        """Datasets are equal when both info and the full sample table match."""
        if not isinstance(value, HafniaDataset):
            return False

        if self.info != value.info:
            return False

        if not isinstance(self.samples, pl.DataFrame) or not isinstance(value.samples, pl.DataFrame):
            return False

        if not self.samples.equals(value.samples):
            return False
        return True

    def print_stats(self) -> None:
        """Print dataset statistics (basic info + per-split, per-task object counts) as rich tables."""
        table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
        table_base.add_column("Property", style="cyan")
        table_base.add_column("Value")
        table_base.add_row("Dataset Name", self.info.dataset_name)
        table_base.add_row("Version", self.info.version)
        table_base.add_row("Number of samples", str(len(self.samples)))
        rprint(table_base)
        rprint(self.info.tasks)

        splits_sets = {
            "All": SplitName.valid_splits(),
            "Train": [SplitName.TRAIN],
            "Validation": [SplitName.VAL],
            "Test": [SplitName.TEST],
        }
        rows = []
        for split_name, splits in splits_sets.items():
            dataset_split = self.create_split_dataset(splits)
            table = dataset_split.samples
            row = {}
            row["Split"] = split_name
            row["Sample "] = str(len(table))
            for PrimitiveType in PRIMITIVE_TYPES:
                column_name = PrimitiveType.column_name()
                objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
                if objects_df is None:
                    continue
                for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
                    count = len(object_group[FieldName.CLASS_NAME])
                    row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
            rows.append(row)

        rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
        for i_row, row in enumerate(rows):
            if i_row == 0:
                # Column headers come from the first row ("All" splits), which
                # contains every task seen in the dataset.
                for column_name in row.keys():
                    rich_table.add_column(column_name, justify="left", style="cyan")
            rich_table.add_row(*[str(value) for value in row.values()])
        rprint(rich_table)
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
    """Load the dataset at *path_dataset* (verifying image files exist) and run all consistency checks."""
    loaded_dataset = HafniaDataset.read_from_path(path_dataset, check_for_images=True)
    check_hafnia_dataset(loaded_dataset)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def check_hafnia_dataset(dataset: HafniaDataset):
    """
    Validate the internal consistency of a HafniaDataset.

    Checks performed:
    - info has a non-empty name and version
    - the table contains at least one 'is_sample=True' row and all valid splits
    - every task in 'info.tasks' has a matching, non-empty primitive column whose
      classes are a subset of the task's 'class_names' and whose class indices
      match the ordering of 'class_names'
    - every task found in the table is declared in 'info.tasks' (or distributions)
    - every row validates as a 'Sample' pydantic model

    Raises ValueError (or AssertionError) on the first inconsistency found.
    """
    user_logger.info("Checking Hafnia dataset...")
    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

    # At least one row must be flagged as a sample.
    is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
    if True not in is_sample_list:
        raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")

    # All valid splits (train/val/test) must be present - no more, no less.
    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
    expected_splits = SplitName.valid_splits()
    if set(actual_splits) != set(expected_splits):
        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")

    expected_tasks = dataset.info.tasks
    for task in expected_tasks:
        primitive = task.primitive.__name__
        column_name = task.primitive.column_name()
        primitive_column = dataset.samples[column_name]
        # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
        msg_something_wrong = (
            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
            f"For '{primitive=}' and '{task.name=}' "
        )
        if primitive_column.dtype == pl.Null:
            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")

        # Flatten the list-of-structs column and keep only this task's annotations.
        primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
        if primitive_table.is_empty():
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
            )

        actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
        if task.class_names is None:
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
            )
        defined_classes = set(task.class_names)

        if not actual_classes.issubset(defined_classes):
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
                f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
            )
        # Check class_indices
        # Recompute each row's class index from the ordering of 'task.class_names'
        # and require it to match the stored class_idx column exactly.
        mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
            lambda x: task.class_names.index(x), return_dtype=pl.Int64
        )
        table_indices = primitive_table[FieldName.CLASS_IDX]

        error_msg = msg_something_wrong + (
            f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
        )
        assert mapped_indices.equals(table_indices), error_msg

    distribution = dataset.info.distributions or []
    distribution_names = [task.name for task in distribution]
    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
    for PrimitiveType in PRIMITIVE_TYPES:
        column_name = PrimitiveType.column_name()
        if column_name not in dataset.samples.columns:
            continue
        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
        if objects_df is None:
            continue
        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
            if has_task:
                continue
            # Distribution tasks are tracked separately and do not need a TaskInfo entry.
            if task_name in distribution_names:
                continue
            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
            raise ValueError(
                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
                f"classes: {class_names}. "
            )

    for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
        sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation # noqa: F841
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Type
|
|
4
|
+
|
|
5
|
+
from .bbox import Bbox
|
|
6
|
+
from .bitmask import Bitmask
|
|
7
|
+
from .classification import Classification
|
|
8
|
+
from .point import Point # noqa: F401
|
|
9
|
+
from .polygon import Polygon
|
|
10
|
+
from .primitive import Primitive
|
|
11
|
+
from .segmentation import Segmentation # noqa: F401
|
|
12
|
+
from .utils import class_color_by_name # noqa: F401
|
|
13
|
+
|
|
14
|
+
# All concrete annotation primitives registered in this package, in canonical order.
PRIMITIVE_TYPES: List[Type[Primitive]] = [Bbox, Classification, Polygon, Bitmask]
# Lookup from class name (e.g. "Bbox") to the primitive class itself.
PRIMITIVE_NAME_TO_TYPE = {cls.__name__: cls for cls in PRIMITIVE_TYPES}
# Sample-table column name used by each primitive (e.g. Bbox -> "objects").
PRIMITIVE_COLUMN_NAMES: List[str] = [PrimitiveType.column_name() for PrimitiveType in PRIMITIVE_TYPES]
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import cv2
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from hafnia.dataset.primitives.primitive import Primitive
|
|
9
|
+
from hafnia.dataset.primitives.utils import (
|
|
10
|
+
anonymize_by_resizing,
|
|
11
|
+
class_color_by_name,
|
|
12
|
+
clip,
|
|
13
|
+
get_class_name,
|
|
14
|
+
round_int_clip_value,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Bbox(Primitive):
    """
    Axis-aligned bounding box annotation with normalized coordinates: all of
    height/width/top_left_x/top_left_y are fractions of the image size in [0, 1].
    """

    # Names should match names in FieldName
    height: float  # Height of the bounding box as a fraction of the image height, e.g. 0.1 for 10% of the image height
    width: float  # Width of the bounding box as a fraction of the image width, e.g. 0.1 for 10% of the image width
    top_left_x: float  # X coordinate of top-left corner of Bbox as a fraction of the image width, e.g. 0.1 for 10% of the image width
    top_left_y: float  # Y coordinate of top-left corner of Bbox as a fraction of the image height, e.g. 0.1 for 10% of the image height
    class_name: Optional[str] = None  # Class name, e.g. "car"
    class_idx: Optional[int] = None  # Class index, e.g. 0 for "car" if it is the first class
    object_id: Optional[str] = None  # Unique identifier for the object, e.g. "12345123"
    confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
    ground_truth: bool = True  # Whether this is ground truth or a prediction

    task_name: str = ""  # Task name to support multiple Bbox tasks in the same dataset. "" defaults to "bboxes"
    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask

    @staticmethod
    def default_task_name() -> str:
        """Default task name used when 'task_name' is left empty."""
        return "bboxes"

    @staticmethod
    def column_name() -> str:
        """Name of the sample-table column that stores Bbox annotations."""
        return "objects"

    def calculate_area(self) -> float:
        """Return the box area as a fraction of the image area (normalized units)."""
        return self.height * self.width

    @staticmethod
    def from_coco(bbox: List, height: int, width: int) -> Bbox:
        """
        Converts a COCO-style bounding box to a Bbox object.
        The bbox is in the format [x_min, y_min, width, height].
        """
        # COCO boxes are in pixels; divide by the image size to normalize.
        x_min, y_min, bbox_width, bbox_height = bbox
        return Bbox(
            top_left_x=x_min / width,
            top_left_y=y_min / height,
            width=bbox_width / width,
            height=bbox_height / height,
        )

    def to_bbox(self) -> Tuple[float, float, float, float]:
        """
        Converts Bbox to a tuple of (x_min, y_min, width, height) with normalized coordinates.
        Values are floats in the range [0, 1].
        """
        return (self.top_left_x, self.top_left_y, self.width, self.height)

    def to_coco(self, image_height: int, image_width: int) -> Tuple[int, int, int, int]:
        """Convert to COCO pixel format (x_min, y_min, width, height), rounded and clipped to the image."""
        xmin = round_int_clip_value(self.top_left_x * image_width, max_value=image_width)
        bbox_width = round_int_clip_value(self.width * image_width, max_value=image_width)

        ymin = round_int_clip_value(self.top_left_y * image_height, max_value=image_height)
        bbox_height = round_int_clip_value(self.height * image_height, max_value=image_height)

        return xmin, ymin, bbox_width, bbox_height

    def to_pixel_coordinates(
        self, image_shape: Tuple[int, int], as_int: bool = True, clip_values: bool = True
    ) -> Union[Tuple[float, float, float, float], Tuple[int, int, int, int]]:
        """
        Return corner coordinates (xmin, ymin, xmax, ymax) in pixels for an image
        of shape (height, width), optionally rounded to int and clipped to the image.
        """
        bb_height = self.height * image_shape[0]
        bb_width = self.width * image_shape[1]
        bb_top_left_x = self.top_left_x * image_shape[1]
        bb_top_left_y = self.top_left_y * image_shape[0]
        xmin, ymin, xmax, ymax = bb_top_left_x, bb_top_left_y, bb_top_left_x + bb_width, bb_top_left_y + bb_height

        if as_int:
            xmin, ymin, xmax, ymax = int(round(xmin)), int(round(ymin)), int(round(xmax)), int(round(ymax))  # noqa: RUF046

        if clip_values:
            xmin = clip(value=xmin, v_min=0, v_max=image_shape[1])
            xmax = clip(value=xmax, v_min=0, v_max=image_shape[1])
            ymin = clip(value=ymin, v_min=0, v_max=image_shape[0])
            ymax = clip(value=ymax, v_min=0, v_max=image_shape[0])

        return xmin, ymin, xmax, ymax

    def draw(self, image: np.ndarray, inplace: bool = False, draw_label: bool = True) -> np.ndarray:
        """Draw the box (and optionally its class label) on *image*; copies unless *inplace*."""
        if not inplace:
            image = image.copy()
        xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])

        class_name = self.get_class_name()
        color = class_color_by_name(class_name)
        font = cv2.FONT_HERSHEY_SIMPLEX
        margin = 5
        # Label is anchored just inside the bottom-left corner of the box.
        bottom_left = (xmin + margin, ymax - margin)
        if draw_label:
            cv2.putText(
                img=image, text=class_name, org=bottom_left, fontFace=font, fontScale=0.75, color=color, thickness=2
            )
        cv2.rectangle(image, pt1=(xmin, ymin), pt2=(xmax, ymax), color=color, thickness=2)

        return image

    def mask(
        self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
    ) -> np.ndarray:
        """
        Fill the box region with *color* (defaults to the region's mean color);
        copies unless *inplace*.
        """
        if not inplace:
            image = image.copy()
        xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)

        if color is None:
            color = np.mean(image[ymin:ymax, xmin:xmax], axis=(0, 1)).astype(np.uint8)

        image[ymin:ymax, xmin:xmax] = color
        return image

    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
        """
        Anonymize the box region by downsampling it to at most *max_resolution*
        and resizing back (pixelation); copies unless *inplace*.
        """
        if not inplace:
            image = image.copy()
        xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
        blur_region = image[ymin:ymax, xmin:xmax]
        blur_region_upsized = anonymize_by_resizing(blur_region, max_resolution=max_resolution)
        image[ymin:ymax, xmin:xmax] = blur_region_upsized
        return image

    def get_class_name(self) -> str:
        """Return the class name, falling back to the class index via the shared helper."""
        return get_class_name(self.class_name, self.class_idx)