hafnia 0.1.26__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. cli/__main__.py +2 -2
  2. cli/dataset_cmds.py +60 -0
  3. cli/runc_cmds.py +1 -1
  4. hafnia/data/__init__.py +2 -2
  5. hafnia/data/factory.py +9 -56
  6. hafnia/dataset/dataset_helpers.py +91 -0
  7. hafnia/dataset/dataset_names.py +71 -0
  8. hafnia/dataset/dataset_transformation.py +187 -0
  9. hafnia/dataset/dataset_upload_helper.py +468 -0
  10. hafnia/dataset/hafnia_dataset.py +453 -0
  11. hafnia/dataset/primitives/__init__.py +16 -0
  12. hafnia/dataset/primitives/bbox.py +137 -0
  13. hafnia/dataset/primitives/bitmask.py +182 -0
  14. hafnia/dataset/primitives/classification.py +56 -0
  15. hafnia/dataset/primitives/point.py +25 -0
  16. hafnia/dataset/primitives/polygon.py +100 -0
  17. hafnia/dataset/primitives/primitive.py +44 -0
  18. hafnia/dataset/primitives/segmentation.py +51 -0
  19. hafnia/dataset/primitives/utils.py +51 -0
  20. hafnia/dataset/table_transformations.py +183 -0
  21. hafnia/experiment/hafnia_logger.py +2 -2
  22. hafnia/helper_testing.py +63 -0
  23. hafnia/http.py +5 -3
  24. hafnia/platform/__init__.py +2 -2
  25. hafnia/platform/builder.py +25 -19
  26. hafnia/platform/datasets.py +184 -0
  27. hafnia/platform/download.py +85 -23
  28. hafnia/torch_helpers.py +180 -95
  29. hafnia/utils.py +1 -1
  30. hafnia/visualizations/colors.py +267 -0
  31. hafnia/visualizations/image_visualizations.py +202 -0
  32. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
  33. hafnia-0.2.0.dist-info/RECORD +46 -0
  34. cli/data_cmds.py +0 -53
  35. hafnia-0.1.26.dist-info/RECORD +0 -27
  36. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
  37. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
  38. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
@@ -0,0 +1,453 @@
+ from __future__ import annotations
+
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Type, Union
+
+ import more_itertools
+ import numpy as np
+ import polars as pl
+ import rich
+ from PIL import Image
+ from pydantic import BaseModel, field_serializer, field_validator
+ from rich import print as rprint
+ from rich.table import Table
+ from tqdm import tqdm
+
+ from hafnia.dataset import dataset_helpers, dataset_transformation
+ from hafnia.dataset.dataset_names import (
+     DATASET_FILENAMES,
+     FILENAME_ANNOTATIONS_JSONL,
+     FILENAME_ANNOTATIONS_PARQUET,
+     FILENAME_DATASET_INFO,
+     ColumnName,
+     FieldName,
+     SplitName,
+ )
+ from hafnia.dataset.primitives import (
+     PRIMITIVE_NAME_TO_TYPE,
+     PRIMITIVE_TYPES,
+ )
+ from hafnia.dataset.primitives.bbox import Bbox
+ from hafnia.dataset.primitives.bitmask import Bitmask
+ from hafnia.dataset.primitives.classification import Classification
+ from hafnia.dataset.primitives.polygon import Polygon
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.dataset.table_transformations import (
+     check_image_paths,
+     create_primitive_table,
+     read_table_from_path,
+ )
+ from hafnia.log import user_logger
+
+
+ class TaskInfo(BaseModel):
+     primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
+     class_names: Optional[List[str]]  # Class names for the task. Specify class_names to get consistent class indices.
+     name: Optional[str] = (
+         None  # None to use the default primitive task name: Bbox -> "bboxes", Bitmask -> "bitmasks", etc.
+     )
+
+     def model_post_init(self, __context: Any) -> None:
+         if self.name is None:
+             self.name = self.primitive.default_task_name()
+
+     # The 'primitive' field of type 'Type[Primitive]' is not supported by pydantic out of the box, because
+     # 'Primitive' is an abstract base class and the actual primitives are subclasses such as Bbox, Bitmask
+     # and Classification. The validator and serializer below ('ensure_primitive' and 'serialize_primitive')
+     # ensure that the 'primitive' field correctly validates and serializes those subclasses.
+     @field_validator("primitive", mode="plain")
+     @classmethod
+     def ensure_primitive(cls, primitive: Any) -> Any:
+         if isinstance(primitive, str):
+             if primitive not in PRIMITIVE_NAME_TO_TYPE:
+                 raise ValueError(
+                     f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
+                 )
+
+             return PRIMITIVE_NAME_TO_TYPE[primitive]
+
+         if issubclass(primitive, Primitive):
+             return primitive
+
+         raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")
+
+     @field_serializer("primitive")
+     @classmethod
+     def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
+         if not issubclass(primitive, Primitive):
+             raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
+         return primitive.__name__
+
+
+ class DatasetInfo(BaseModel):
+     dataset_name: str
+     version: str
+     tasks: list[TaskInfo]
+     distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
+     meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.
+
+     def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
+         json_str = self.model_dump_json(indent=indent)
+         path.write_text(json_str)
+
+     @staticmethod
+     def from_json_file(path: Path) -> "DatasetInfo":
+         json_str = path.read_text()
+         return DatasetInfo.model_validate_json(json_str)
+
+
+ class Sample(BaseModel):
+     file_name: str
+     height: int
+     width: int
+     split: str  # Split name, e.g., "train", "val", "test"
+     is_sample: bool  # Indicates if this is a sample (True) or a metadata entry (False)
+     collection_index: Optional[int] = None  # Optional, e.g. frame number for video datasets
+     collection_id: Optional[str] = None  # Optional, e.g. video name for video datasets
+     remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
+     sample_index: Optional[int] = None  # Don't set this manually; it is used for indexing samples in the dataset.
+     classifications: Optional[List[Classification]] = None  # Optional classification primitives
+     objects: Optional[List[Bbox]] = None  # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
+     bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
+     polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable
+
+     meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.
+
+     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
+         """
+         Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.
+         """
+         primitive_types = primitive_types or PRIMITIVE_TYPES
+         annotations_primitives = [
+             getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
+         ]
+         annotations = more_itertools.flatten(
+             [primitives for primitives in annotations_primitives if primitives is not None]
+         )
+
+         return list(annotations)
+
+     def read_image_pillow(self) -> Image.Image:
+         """
+         Reads the image from the file path and returns it as a PIL Image.
+         Raises FileNotFoundError if the image file does not exist.
+         """
+         path_image = Path(self.file_name)
+         if not path_image.exists():
+             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
+
+         image = Image.open(str(path_image))
+         return image
+
+     def read_image(self) -> np.ndarray:
+         image_pil = self.read_image_pillow()
+         image = np.array(image_pil)
+         return image
+
+     def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
+         from hafnia.visualizations import image_visualizations
+
+         if image is None:  # Explicit None-check: 'image or ...' is ambiguous for numpy arrays
+             image = self.read_image()
+         annotations = self.get_annotations()
+         annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
+         return annotations_visualized
+
+
+ @dataclass
+ class HafniaDataset:
+     info: DatasetInfo
+     samples: pl.DataFrame
+
+     def __getitem__(self, item: int) -> Dict[str, Any]:
+         return self.samples.row(index=item, named=True)
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __iter__(self):
+         for row in self.samples.iter_rows(named=True):
+             yield row
+
+     # Dataset transformations
+     apply_image_transform = dataset_transformation.transform_images
+     sample = dataset_transformation.sample
+     shuffle = dataset_transformation.shuffle_dataset
+     split_by_ratios = dataset_transformation.splits_by_ratios
+     divide_split_into_multiple_splits = dataset_transformation.divide_split_into_multiple_splits
+     sample_set_by_size = dataset_transformation.define_sample_set_by_size
+
+     @staticmethod
+     def from_samples_list(samples_list: List, info: DatasetInfo) -> "HafniaDataset":
+         sample = samples_list[0]
+         if isinstance(sample, Sample):
+             json_samples = [sample.model_dump(mode="json") for sample in samples_list]
+         elif isinstance(sample, dict):
+             json_samples = samples_list
+         else:
+             raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")
+
+         table = pl.from_records(json_samples).drop(ColumnName.SAMPLE_INDEX)
+         table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)  # Add sample index column
+
+         return HafniaDataset(info=info, samples=table)
+
+     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+         if ColumnName.SPLIT not in self.samples.columns:
+             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
+
+         splits = {}
+         for split_name in SplitName.valid_splits():
+             splits[split_name] = self.create_split_dataset(split_name)
+
+         return splits
+
+     def create_sample_dataset(self) -> "HafniaDataset":
+         if ColumnName.IS_SAMPLE not in self.samples.columns:
+             raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
+         table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
+         return self.update_table(table)
+
+     def create_split_dataset(self, split_name: Union[str, List[str]]) -> "HafniaDataset":
+         if isinstance(split_name, str):
+             split_names = [split_name]
+         elif isinstance(split_name, list):
+             split_names = split_name
+
+         for name in split_names:
+             if name not in SplitName.valid_splits():
+                 raise ValueError(f"Invalid split name: {name}. Valid splits are: {SplitName.valid_splits()}")
+
+         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+         return self.update_table(filtered_dataset)
+
+     def get_task_by_name(self, task_name: str) -> TaskInfo:
+         for task in self.info.tasks:
+             if task.name == task_name:
+                 return task
+         raise ValueError(f"Task with name {task_name} not found in dataset info.")
+
+     def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
+         return HafniaDataset(info=self.info.model_copy(), samples=table)
+
+     @staticmethod
+     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
+         """
+         Checks if the dataset path exists and contains the required files.
+         Returns True if the dataset is valid, otherwise raises an error or returns False.
+         """
+         if not path_dataset.exists():
+             if raise_error:
+                 raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
+             return False
+
+         required_files = [
+             FILENAME_DATASET_INFO,
+             FILENAME_ANNOTATIONS_JSONL,
+             FILENAME_ANNOTATIONS_PARQUET,
+         ]
+         for filename in required_files:
+             if not (path_dataset / filename).exists():
+                 if raise_error:
+                     raise FileNotFoundError(f"Required file {filename} not found in {path_dataset}.")
+                 return False
+
+         return True
+
+     @staticmethod
+     def read_from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+         HafniaDataset.check_dataset_path(path_folder, raise_error=True)
+
+         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
+         table = read_table_from_path(path_folder)
+
+         # Convert from relative paths to absolute paths
+         table = table.with_columns(
+             pl.concat_str([pl.lit(str(path_folder.absolute()) + os.sep), pl.col("file_name")]).alias("file_name")
+         )
+         if check_for_images:
+             check_image_paths(table)
+         return HafniaDataset(samples=table, info=dataset_info)
+
+     def write(self, path_folder: Path, name_by_hash: bool = True, add_version: bool = False) -> None:
+         user_logger.info(f"Writing dataset to {path_folder}...")
+         if not path_folder.exists():
+             path_folder.mkdir(parents=True)
+         path_folder_images = path_folder / "data"
+         path_folder_images.mkdir(parents=True, exist_ok=True)
+
+         new_relative_paths = []
+         for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
+             org_path = Path(org_path)
+             if not org_path.exists():
+                 raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+             if name_by_hash:
+                 filename = dataset_helpers.filename_as_hash_from_path(org_path)
+             else:
+                 filename = Path(org_path).name
+             new_path = path_folder_images / filename
+             if not new_path.exists():
+                 shutil.copy2(org_path, new_path)
+
+             if not new_path.exists():
+                 raise FileNotFoundError(f"File {new_path} does not exist in the dataset.")
+             new_relative_paths.append(str(new_path.relative_to(path_folder)))
+
+         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # JSON for readability
+         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
+         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+
+         if add_version:
+             path_version = path_folder / "versions" / f"{self.info.version}"
+             path_version.mkdir(parents=True, exist_ok=True)
+             for filename in DATASET_FILENAMES:
+                 shutil.copy2(path_folder / filename, path_version / filename)
+
+     def __eq__(self, value) -> bool:
+         if not isinstance(value, HafniaDataset):
+             return False
+
+         if self.info != value.info:
+             return False
+
+         if not isinstance(self.samples, pl.DataFrame) or not isinstance(value.samples, pl.DataFrame):
+             return False
+
+         if not self.samples.equals(value.samples):
+             return False
+         return True
+
+     def print_stats(self) -> None:
+         table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+         table_base.add_column("Property", style="cyan")
+         table_base.add_column("Value")
+         table_base.add_row("Dataset Name", self.info.dataset_name)
+         table_base.add_row("Version", self.info.version)
+         table_base.add_row("Number of samples", str(len(self.samples)))
+         rprint(table_base)
+         rprint(self.info.tasks)
+
+         splits_sets = {
+             "All": SplitName.valid_splits(),
+             "Train": [SplitName.TRAIN],
+             "Validation": [SplitName.VAL],
+             "Test": [SplitName.TEST],
+         }
+         rows = []
+         for split_name, splits in splits_sets.items():
+             dataset_split = self.create_split_dataset(splits)
+             table = dataset_split.samples
+             row = {}
+             row["Split"] = split_name
+             row["Samples"] = str(len(table))
+             for PrimitiveType in PRIMITIVE_TYPES:
+                 column_name = PrimitiveType.column_name()
+                 objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+                 if objects_df is None:
+                     continue
+                 for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                     count = len(object_group[FieldName.CLASS_NAME])
+                     row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+             rows.append(row)
+
+         rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+         for i_row, row in enumerate(rows):
+             if i_row == 0:
+                 for column_name in row.keys():
+                     rich_table.add_column(column_name, justify="left", style="cyan")
+             rich_table.add_row(*[str(value) for value in row.values()])
+         rprint(rich_table)
+
+
+ def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
+     dataset = HafniaDataset.read_from_path(path_dataset, check_for_images=True)
+     check_hafnia_dataset(dataset)
+
+
+ def check_hafnia_dataset(dataset: HafniaDataset):
+     user_logger.info("Checking Hafnia dataset...")
+     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
+     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+     is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
+     if True not in is_sample_list:
+         raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
+
+     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+     expected_splits = SplitName.valid_splits()
+     if set(actual_splits) != set(expected_splits):
+         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+     expected_tasks = dataset.info.tasks
+     for task in expected_tasks:
+         primitive = task.primitive.__name__
+         column_name = task.primitive.column_name()
+         primitive_column = dataset.samples[column_name]
+         msg_something_wrong = (
+             f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+             f"For '{primitive=}' and '{task.name=}' "
+         )
+         if primitive_column.dtype == pl.Null:
+             raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+         primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+         if primitive_table.is_empty():
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+             )
+
+         actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+         if task.class_names is None:
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+             )
+         defined_classes = set(task.class_names)
+
+         if not actual_classes.issubset(defined_classes):
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' with {task.name=} expected the actual classes in the dataset \n"
+                 f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+             )
+         # Check class indices
+         mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+             lambda x: task.class_names.index(x), return_dtype=pl.Int64
+         )
+         table_indices = primitive_table[FieldName.CLASS_IDX]
+
+         error_msg = msg_something_wrong + (
+             f"class indices in the '{FieldName.CLASS_IDX}' column do not match the class ordering in 'task.class_names'"
+         )
+         assert mapped_indices.equals(table_indices), error_msg
+
+     distribution = dataset.info.distributions or []
+     distribution_names = [task.name for task in distribution]
+     # Check that tasks found in 'dataset.samples' match the tasks defined in 'dataset.info.tasks'
+     for PrimitiveType in PRIMITIVE_TYPES:
+         column_name = PrimitiveType.column_name()
+         if column_name not in dataset.samples.columns:
+             continue
+         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+         if objects_df is None:
+             continue
+         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+             if has_task:
+                 continue
+             if task_name in distribution_names:
+                 continue
+             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+             raise ValueError(
+                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                 f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. The missing task has the following "
+                 f"classes: {class_names}. "
+             )
+
+     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+         sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation  # noqa: F841
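
The hunk above introduces the core in-memory dataset API: TaskInfo/DatasetInfo describe the annotation tasks, Sample is the pydantic row model, and HafniaDataset wraps a polars DataFrame of samples. As a minimal sketch (not taken from the package docs, based only on the code shown above), the pieces compose roughly like this; the dataset name, file path, and the literal split string "train" are placeholder assumptions:

    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
    from hafnia.dataset.primitives.bbox import Bbox

    # 'primitive' may be given as a string; TaskInfo.ensure_primitive resolves it via PRIMITIVE_NAME_TO_TYPE.
    task = TaskInfo(primitive="Bbox", class_names=["car", "person"])
    info = DatasetInfo(dataset_name="demo-dataset", version="0.0.1", tasks=[task])

    samples = [
        Sample(
            file_name="images/0001.jpg",  # placeholder; read_image() and write() expect a real file on disk
            height=720,
            width=1280,
            split="train",  # assumes SplitName.TRAIN corresponds to the string "train"
            is_sample=True,
            objects=[Bbox(top_left_x=0.1, top_left_y=0.2, width=0.3, height=0.4, class_name="car", class_idx=0)],
        )
    ]

    dataset = HafniaDataset.from_samples_list(samples, info=info)
    print(len(dataset))         # 1
    print(dataset[0]["split"])  # "train"
    train_only = dataset.create_split_dataset("train")  # filter the sample table by split

Writing to disk with dataset.write(...) and reloading with HafniaDataset.read_from_path(...) additionally require the referenced image files to exist, since write() copies them into the dataset folder.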
hafnia/dataset/primitives/__init__.py
@@ -0,0 +1,16 @@
+ from __future__ import annotations
+
+ from typing import List, Type
+
+ from .bbox import Bbox
+ from .bitmask import Bitmask
+ from .classification import Classification
+ from .point import Point  # noqa: F401
+ from .polygon import Polygon
+ from .primitive import Primitive
+ from .segmentation import Segmentation  # noqa: F401
+ from .utils import class_color_by_name  # noqa: F401
+
+ PRIMITIVE_TYPES: List[Type[Primitive]] = [Bbox, Classification, Polygon, Bitmask]
+ PRIMITIVE_NAME_TO_TYPE = {cls.__name__: cls for cls in PRIMITIVE_TYPES}
+ PRIMITIVE_COLUMN_NAMES: List[str] = [PrimitiveType.column_name() for PrimitiveType in PRIMITIVE_TYPES]
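
This package acts as a small registry: PRIMITIVE_NAME_TO_TYPE maps class names to primitive classes (used by TaskInfo.ensure_primitive above), and each primitive declares the table column it lives in. A hedged sketch of how the registry might be used, based only on the definitions shown in this diff:

    from hafnia.dataset.primitives import PRIMITIVE_NAME_TO_TYPE, PRIMITIVE_TYPES

    # Resolve a primitive class from its name, as TaskInfo.ensure_primitive does.
    bbox_cls = PRIMITIVE_NAME_TO_TYPE["Bbox"]
    print(bbox_cls.default_task_name())  # "bboxes"
    print(bbox_cls.column_name())        # "objects" (the Sample/table column for this primitive)

    # Column names for all registered primitives, e.g. when scanning a dataset table for annotations.
    print([cls.column_name() for cls in PRIMITIVE_TYPES])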
hafnia/dataset/primitives/bbox.py
@@ -0,0 +1,137 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import cv2
+ import numpy as np
+
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.dataset.primitives.utils import (
+     anonymize_by_resizing,
+     class_color_by_name,
+     clip,
+     get_class_name,
+     round_int_clip_value,
+ )
+
+
+ class Bbox(Primitive):
+     # Names should match names in FieldName
+     height: float  # Height of the bounding box as a fraction of the image height, e.g. 0.1 for 10% of the image height
+     width: float  # Width of the bounding box as a fraction of the image width, e.g. 0.1 for 10% of the image width
+     top_left_x: float  # X coordinate of the top-left corner of the Bbox as a fraction of the image width, e.g. 0.1 for 10% of the image width
+     top_left_y: float  # Y coordinate of the top-left corner of the Bbox as a fraction of the image height, e.g. 0.1 for 10% of the image height
+     class_name: Optional[str] = None  # Class name, e.g. "car"
+     class_idx: Optional[int] = None  # Class index, e.g. 0 for "car" if it is the first class
+     object_id: Optional[str] = None  # Unique identifier for the object, e.g. "12345123"
+     confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
+     ground_truth: bool = True  # Whether this is ground truth or a prediction
+
+     task_name: str = ""  # Task name to support multiple Bbox tasks in the same dataset. "" defaults to "bboxes"
+     meta: Optional[Dict[str, Any]] = None  # Can be used to store additional information about the bounding box
+
+     @staticmethod
+     def default_task_name() -> str:
+         return "bboxes"
+
+     @staticmethod
+     def column_name() -> str:
+         return "objects"
+
+     def calculate_area(self) -> float:
+         return self.height * self.width
+
+     @staticmethod
+     def from_coco(bbox: List, height: int, width: int) -> Bbox:
+         """
+         Converts a COCO-style bounding box to a Bbox object.
+         The bbox is in the format [x_min, y_min, width, height].
+         """
+         x_min, y_min, bbox_width, bbox_height = bbox
+         return Bbox(
+             top_left_x=x_min / width,
+             top_left_y=y_min / height,
+             width=bbox_width / width,
+             height=bbox_height / height,
+         )
+
+     def to_bbox(self) -> Tuple[float, float, float, float]:
+         """
+         Converts Bbox to a tuple of (x_min, y_min, width, height) with normalized coordinates.
+         Values are floats in the range [0, 1].
+         """
+         return (self.top_left_x, self.top_left_y, self.width, self.height)
+
+     def to_coco(self, image_height: int, image_width: int) -> Tuple[int, int, int, int]:
+         xmin = round_int_clip_value(self.top_left_x * image_width, max_value=image_width)
+         bbox_width = round_int_clip_value(self.width * image_width, max_value=image_width)
+
+         ymin = round_int_clip_value(self.top_left_y * image_height, max_value=image_height)
+         bbox_height = round_int_clip_value(self.height * image_height, max_value=image_height)
+
+         return xmin, ymin, bbox_width, bbox_height
+
+     def to_pixel_coordinates(
+         self, image_shape: Tuple[int, int], as_int: bool = True, clip_values: bool = True
+     ) -> Union[Tuple[float, float, float, float], Tuple[int, int, int, int]]:
+         bb_height = self.height * image_shape[0]
+         bb_width = self.width * image_shape[1]
+         bb_top_left_x = self.top_left_x * image_shape[1]
+         bb_top_left_y = self.top_left_y * image_shape[0]
+         xmin, ymin, xmax, ymax = bb_top_left_x, bb_top_left_y, bb_top_left_x + bb_width, bb_top_left_y + bb_height
+
+         if as_int:
+             xmin, ymin, xmax, ymax = int(round(xmin)), int(round(ymin)), int(round(xmax)), int(round(ymax))  # noqa: RUF046
+
+         if clip_values:
+             xmin = clip(value=xmin, v_min=0, v_max=image_shape[1])
+             xmax = clip(value=xmax, v_min=0, v_max=image_shape[1])
+             ymin = clip(value=ymin, v_min=0, v_max=image_shape[0])
+             ymax = clip(value=ymax, v_min=0, v_max=image_shape[0])
+
+         return xmin, ymin, xmax, ymax
+
+     def draw(self, image: np.ndarray, inplace: bool = False, draw_label: bool = True) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])
+
+         class_name = self.get_class_name()
+         color = class_color_by_name(class_name)
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         margin = 5
+         bottom_left = (xmin + margin, ymax - margin)
+         if draw_label:
+             cv2.putText(
+                 img=image, text=class_name, org=bottom_left, fontFace=font, fontScale=0.75, color=color, thickness=2
+             )
+         cv2.rectangle(image, pt1=(xmin, ymin), pt2=(xmax, ymax), color=color, thickness=2)
+
+         return image
+
+     def mask(
+         self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
+     ) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])
+         xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
+
+         if color is None:
+             color = np.mean(image[ymin:ymax, xmin:xmax], axis=(0, 1)).astype(np.uint8)
+
+         image[ymin:ymax, xmin:xmax] = color
+         return image
+
+     def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         xmin, ymin, xmax, ymax = self.to_pixel_coordinates(image_shape=image.shape[:2])
+         xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
+         blur_region = image[ymin:ymax, xmin:xmax]
+         blur_region_upsized = anonymize_by_resizing(blur_region, max_resolution=max_resolution)
+         image[ymin:ymax, xmin:xmax] = blur_region_upsized
+         return image
+
+     def get_class_name(self) -> str:
+         return get_class_name(self.class_name, self.class_idx)
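
Bbox stores coordinates as fractions of the image size, while COCO-style boxes and the drawing helpers work in pixels. A hedged round-trip sketch using only the methods shown above; the 640x480 image size is an arbitrary example value, and the to_coco output assumes round_int_clip_value rounds and clips as its name suggests:

    from hafnia.dataset.primitives.bbox import Bbox

    image_height, image_width = 480, 640
    coco_box = [64, 48, 320, 240]  # [x_min, y_min, width, height] in pixels

    bbox = Bbox.from_coco(coco_box, height=image_height, width=image_width)
    print(bbox.to_bbox())         # (0.1, 0.1, 0.5, 0.5) normalized (x_min, y_min, width, height)
    print(bbox.calculate_area())  # 0.25, i.e. the box covers a quarter of the image

    # Back to pixels: corner coordinates for drawing/masking, or COCO format for export.
    print(bbox.to_pixel_coordinates(image_shape=(image_height, image_width)))  # (64, 48, 384, 288)
    print(bbox.to_coco(image_height=image_height, image_width=image_width))    # (64, 48, 320, 240)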