hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +2 -2
  2. cli/config.py +17 -4
  3. cli/dataset_cmds.py +60 -0
  4. cli/runc_cmds.py +1 -1
  5. hafnia/data/__init__.py +2 -2
  6. hafnia/data/factory.py +12 -56
  7. hafnia/dataset/dataset_helpers.py +91 -0
  8. hafnia/dataset/dataset_names.py +72 -0
  9. hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
  10. hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
  11. hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
  12. hafnia/dataset/dataset_upload_helper.py +468 -0
  13. hafnia/dataset/hafnia_dataset.py +624 -0
  14. hafnia/dataset/operations/dataset_stats.py +15 -0
  15. hafnia/dataset/operations/dataset_transformations.py +82 -0
  16. hafnia/dataset/operations/table_transformations.py +183 -0
  17. hafnia/dataset/primitives/__init__.py +16 -0
  18. hafnia/dataset/primitives/bbox.py +137 -0
  19. hafnia/dataset/primitives/bitmask.py +182 -0
  20. hafnia/dataset/primitives/classification.py +56 -0
  21. hafnia/dataset/primitives/point.py +25 -0
  22. hafnia/dataset/primitives/polygon.py +100 -0
  23. hafnia/dataset/primitives/primitive.py +44 -0
  24. hafnia/dataset/primitives/segmentation.py +51 -0
  25. hafnia/dataset/primitives/utils.py +51 -0
  26. hafnia/experiment/hafnia_logger.py +7 -7
  27. hafnia/helper_testing.py +108 -0
  28. hafnia/http.py +5 -3
  29. hafnia/platform/__init__.py +2 -2
  30. hafnia/platform/datasets.py +197 -0
  31. hafnia/platform/download.py +85 -23
  32. hafnia/torch_helpers.py +180 -95
  33. hafnia/utils.py +21 -2
  34. hafnia/visualizations/colors.py +267 -0
  35. hafnia/visualizations/image_visualizations.py +202 -0
  36. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
  37. hafnia-0.2.1.dist-info/RECORD +50 -0
  38. cli/data_cmds.py +0 -53
  39. hafnia-0.1.27.dist-info/RECORD +0 -27
  40. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
  41. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/hafnia_dataset.py (new file)
@@ -0,0 +1,624 @@
+ from __future__ import annotations
+
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+ from random import Random
+ from typing import Any, Dict, List, Optional, Type, Union
+
+ import more_itertools
+ import numpy as np
+ import polars as pl
+ import rich
+ from PIL import Image
+ from pydantic import BaseModel, field_serializer, field_validator
+ from rich import print as rprint
+ from rich.table import Table
+ from tqdm import tqdm
+
+ from hafnia.dataset import dataset_helpers
+ from hafnia.dataset.dataset_names import (
+     DATASET_FILENAMES_REQUIRED,
+     FILENAME_ANNOTATIONS_JSONL,
+     FILENAME_ANNOTATIONS_PARQUET,
+     FILENAME_DATASET_INFO,
+     FILENAME_RECIPE_JSON,
+     ColumnName,
+     FieldName,
+     SplitName,
+ )
+ from hafnia.dataset.operations import dataset_stats, dataset_transformations
+ from hafnia.dataset.operations.table_transformations import (
+     check_image_paths,
+     create_primitive_table,
+     read_table_from_path,
+ )
+ from hafnia.dataset.primitives import (
+     PRIMITIVE_NAME_TO_TYPE,
+     PRIMITIVE_TYPES,
+ )
+ from hafnia.dataset.primitives.bbox import Bbox
+ from hafnia.dataset.primitives.bitmask import Bitmask
+ from hafnia.dataset.primitives.classification import Classification
+ from hafnia.dataset.primitives.polygon import Polygon
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.log import user_logger
+
+
+ class TaskInfo(BaseModel):
+     primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
+     class_names: Optional[List[str]]  # Class names for the tasks. To get consistent class indices specify class_names.
+     name: Optional[str] = (
+         None  # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
+     )
+
+     def model_post_init(self, __context: Any) -> None:
+         if self.name is None:
+             self.name = self.primitive.default_task_name()
+
+     # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box as
+     # the 'Primitive' class is an abstract base class and for the actual primtives such as Bbox, Bitmask, Classification.
+     # Below magic functions ('ensure_primitive' and 'serialize_primitive') ensures that the 'primitive' field can
+     # correctly validate and serialize sub-classes (Bbox, Classification, ...).
+     @field_validator("primitive", mode="plain")
+     @classmethod
+     def ensure_primitive(cls, primitive: Any) -> Any:
+         if isinstance(primitive, str):
+             if primitive not in PRIMITIVE_NAME_TO_TYPE:
+                 raise ValueError(
+                     f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
+                 )
+
+             return PRIMITIVE_NAME_TO_TYPE[primitive]
+
+         if issubclass(primitive, Primitive):
+             return primitive
+
+         raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")
+
+     @field_serializer("primitive")
+     @classmethod
+     def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
+         if not issubclass(primitive, Primitive):
+             raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
+         return primitive.__name__
+
+
+ class DatasetInfo(BaseModel):
+     dataset_name: str
+     version: str
+     tasks: list[TaskInfo]
+     distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
+     meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.
+
+     def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
+         json_str = self.model_dump_json(indent=indent)
+         path.write_text(json_str)
+
+     @staticmethod
+     def from_json_file(path: Path) -> "DatasetInfo":
+         json_str = path.read_text()
+         return DatasetInfo.model_validate_json(json_str)
+
+
+ class Sample(BaseModel):
+     file_name: str
+     height: int
+     width: int
+     split: str  # Split name, e.g., "train", "val", "test"
+     is_sample: bool  # Indicates if this is a sample (True) or a metadata entry (False)
+     collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
+     collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
+     remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
+     sample_index: Optional[int] = None  # Don't manually set this, it is used for indexing samples in the dataset.
+     classifications: Optional[List[Classification]] = None  # Optional classification primitive
+     objects: Optional[List[Bbox]] = None  # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
+     bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
+     polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable
+
+     meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.
+
+     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
+         """
+         Returns a list of all annotations (classifications, objects, bitmasks, polygons) for the sample.
+         """
+         primitive_types = primitive_types or PRIMITIVE_TYPES
+         annotations_primitives = [
+             getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
+         ]
+         annotations = more_itertools.flatten(
+             [primitives for primitives in annotations_primitives if primitives is not None]
+         )
+
+         return list(annotations)
+
+     def read_image_pillow(self) -> Image.Image:
+         """
+         Reads the image from the file path and returns it as a PIL Image.
+         Raises FileNotFoundError if the image file does not exist.
+         """
+         path_image = Path(self.file_name)
+         if not path_image.exists():
+             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
+
+         image = Image.open(str(path_image))
+         return image
+
+     def read_image(self) -> np.ndarray:
+         image_pil = self.read_image_pillow()
+         image = np.array(image_pil)
+         return image
+
+     def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
+         from hafnia.visualizations import image_visualizations
+
+         image = image or self.read_image()
+         annotations = self.get_annotations()
+         annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
+         return annotations_visualized
+
+
+ @dataclass
+ class HafniaDataset:
+     info: DatasetInfo
+     samples: pl.DataFrame
+
+     def __getitem__(self, item: int) -> Dict[str, Any]:
+         return self.samples.row(index=item, named=True)
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __iter__(self):
+         for row in self.samples.iter_rows(named=True):
+             yield row
+
+     @staticmethod
+     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+         HafniaDataset.check_dataset_path(path_folder, raise_error=True)
+
+         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
+         table = read_table_from_path(path_folder)
+
+         # Convert from relative paths to absolute paths
+         table = table.with_columns(
+             pl.concat_str([pl.lit(str(path_folder.absolute()) + os.sep), pl.col("file_name")]).alias("file_name")
+         )
+         if check_for_images:
+             check_image_paths(table)
+         return HafniaDataset(samples=table, info=dataset_info)
+
+     @staticmethod
+     def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
+         """
+         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
+         """
+         from hafnia.dataset.hafnia_dataset import HafniaDataset
+         from hafnia.platform.datasets import download_or_get_dataset_path
+
+         dataset_path = download_or_get_dataset_path(
+             dataset_name=name, force_redownload=force_redownload, download_files=download_files
+         )
+         return HafniaDataset.from_path(dataset_path, check_for_images=download_files)
+
+     @staticmethod
+     def from_samples_list(samples_list: List, info: DatasetInfo) -> "HafniaDataset":
+         sample = samples_list[0]
+         if isinstance(sample, Sample):
+             json_samples = [sample.model_dump(mode="json") for sample in samples_list]
+         elif isinstance(sample, dict):
+             json_samples = samples_list
+         else:
+             raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")
+
+         table = pl.from_records(json_samples).drop(ColumnName.SAMPLE_INDEX)
+         table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)  # Add sample index column
+
+         return HafniaDataset(info=info, samples=table)
+
+     @staticmethod
+     def from_recipe(dataset_recipe: Any) -> "HafniaDataset":
+         """
+         Load a dataset from a recipe. The recipe can be a string (name of the dataset), a dictionary, or a DataRecipe object.
+         """
+         from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe
+
+         recipe_explicit = DatasetRecipe.from_implicit_form(dataset_recipe)
+
+         return recipe_explicit.build()  # Build dataset from the recipe
+
+     @staticmethod
+     def from_merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
+         return HafniaDataset.merge(dataset0, dataset1)
+
+     @staticmethod
+     def from_recipe_with_cache(
+         dataset_recipe: Any,
+         force_redownload: bool = False,
+         path_datasets: Optional[Union[Path, str]] = None,
+     ) -> "HafniaDataset":
+         """
+         Loads a dataset from a recipe and caches it to disk.
+         If the dataset is already cached, it will be loaded from the cache.
+         """
+
+         path_dataset = get_or_create_dataset_path_from_recipe(dataset_recipe, path_datasets=path_datasets)
+         return HafniaDataset.from_path(path_dataset, check_for_images=False)
+
+     @staticmethod
+     def from_merger(
+         datasets: List[HafniaDataset],
+     ) -> "HafniaDataset":
+         """
+         Merges multiple Hafnia datasets into one.
+         """
+         if len(datasets) == 0:
+             raise ValueError("No datasets to merge. Please provide at least one dataset.")
+
+         if len(datasets) == 1:
+             return datasets[0]
+
+         merged_dataset = datasets[0]
+         remaining_datasets = datasets[1:]
+         for dataset in remaining_datasets:
+             merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
+         return merged_dataset
+
+     # Dataset transformations
+     transform_images = dataset_transformations.transform_images
+
+     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
+         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
+         return dataset.update_table(table)
+
+     def select_samples(
+         dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+     ) -> "HafniaDataset":
+         if not with_replacement:
+             n_samples = min(n_samples, len(dataset))
+         table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
+         return dataset.update_table(table)
+
+     def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
+         """
+         Divides the dataset into splits based on the provided ratios.
+
+         Example: Defining split ratios and applying the transformation
+
+         >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+         >>> split_ratios = {SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1}
+         >>> dataset_with_splits = splits_by_ratios(dataset, split_ratios, seed=42)
+         Or use the function as a
+         >>> dataset_with_splits = dataset.splits_by_ratios(split_ratios, seed=42)
+         """
+         n_items = len(dataset)
+         split_name_column = dataset_helpers.create_split_name_list_from_ratios(
+             split_ratios=split_ratios, n_items=n_items, seed=seed
+         )
+         table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
+         return dataset.update_table(table)
+
+     def split_into_multiple_splits(
+         dataset: "HafniaDataset",
+         split_name: str,
+         split_ratios: Dict[str, float],
+     ) -> "HafniaDataset":
+         """
+         Divides a dataset split ('split_name') into multiple splits based on the provided split
+         ratios ('split_ratios'). This is especially useful for some open datasets where they have only provide
+         two splits or only provide annotations for two splits. This function allows you to create additional
+         splits based on the provided ratios.
+
+         Example: Defining split ratios and applying the transformation
+         >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+         >>> split_name = SplitName.TEST
+         >>> split_ratios = {SplitName.TEST: 0.8, SplitName.VAL: 0.2}
+         >>> dataset_with_splits = split_into_multiple_splits(dataset, split_name, split_ratios)
+         """
+         dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
+         if len(dataset_split_to_be_divided) == 0:
+             split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
+             raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
+         assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{split_name}' split!"
+         dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=42)
+
+         remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
+         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
+         dataset_new = dataset.update_table(new_table)
+         return dataset_new
+
+     def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
+         is_sample_column = [False for _ in range(len(dataset))]
+         for idx in is_sample_indices:
+             is_sample_column[idx] = True
+
+         table = dataset.samples.with_columns(pl.Series(is_sample_column).alias("is_sample"))
+         return dataset.update_table(table)
+
+     def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
+         """
+         Merges two Hafnia datasets by concatenating their samples and updating the split names.
+         """
+         ## Currently, only a very naive merging is implemented.
+         # In the future we need to verify that the class and tasks are compatible.
+         # Do they have similar classes and tasks? What to do if they don't?
+         # For now, we just concatenate the samples and keep the split names as they are.
+         merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
+         return dataset0.update_table(merged_samples)
+
+     # Dataset stats
+     split_counts = dataset_stats.split_counts
+
+     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+         if ColumnName.SPLIT not in self.samples.columns:
+             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
+
+         splits = {}
+         for split_name in SplitName.valid_splits():
+             splits[split_name] = self.create_split_dataset(split_name)
+
+         return splits
+
+     def create_sample_dataset(self) -> "HafniaDataset":
+         if ColumnName.IS_SAMPLE not in self.samples.columns:
+             raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
+         table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
+         return self.update_table(table)
+
+     def create_split_dataset(self, split_name: Union[str | List[str]]) -> "HafniaDataset":
+         if isinstance(split_name, str):
+             split_names = [split_name]
+         elif isinstance(split_name, list):
+             split_names = split_name
+
+         for name in split_names:
+             if name not in SplitName.valid_splits():
+                 raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")
+
+         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
+         return self.update_table(filtered_dataset)
+
+     def get_task_by_name(self, task_name: str) -> TaskInfo:
+         for task in self.info.tasks:
+             if task.name == task_name:
+                 return task
+         raise ValueError(f"Task with name {task_name} not found in dataset info.")
+
+     def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
+         return HafniaDataset(info=self.info.model_copy(), samples=table)
+
+     @staticmethod
+     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
+         """
+         Checks if the dataset path exists and contains the required files.
+         Returns True if the dataset is valid, otherwise raises an error or returns False.
+         """
+         if not path_dataset.exists():
+             if raise_error:
+                 raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
+             return False
+
+         required_files = [
+             FILENAME_DATASET_INFO,
+             FILENAME_ANNOTATIONS_JSONL,
+             FILENAME_ANNOTATIONS_PARQUET,
+         ]
+         for filename in required_files:
+             if not (path_dataset / filename).exists():
+                 if raise_error:
+                     raise FileNotFoundError(f"Required file {filename} not found in {path_dataset}.")
+                 return False
+
+         return True
+
+     def write(self, path_folder: Path, name_by_hash: bool = True, add_version: bool = False) -> None:
+         user_logger.info(f"Writing dataset to {path_folder}...")
+         if not path_folder.exists():
+             path_folder.mkdir(parents=True)
+         path_folder_images = path_folder / "data"
+         path_folder_images.mkdir(parents=True, exist_ok=True)
+
+         new_relative_paths = []
+         for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
+             org_path = Path(org_path)
+             if not org_path.exists():
+                 raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+             if name_by_hash:
+                 filename = dataset_helpers.filename_as_hash_from_path(org_path)
+             else:
+                 filename = Path(org_path).name
+             new_path = path_folder_images / filename
+             if not new_path.exists():
+                 shutil.copy2(org_path, new_path)
+
+             if not new_path.exists():
+                 raise FileNotFoundError(f"File {new_path} does not exist in the dataset.")
+             new_relative_paths.append(str(new_path.relative_to(path_folder)))
+
+         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
+         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
+         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
+
+         if add_version:
+             path_version = path_folder / "versions" / f"{self.info.version}"
+             path_version.mkdir(parents=True, exist_ok=True)
+             for filename in DATASET_FILENAMES_REQUIRED:
+                 shutil.copy2(path_folder / filename, path_version / filename)
+
+     def __eq__(self, value) -> bool:
+         if not isinstance(value, HafniaDataset):
+             return False
+
+         if self.info != value.info:
+             return False
+
+         if not isinstance(self.samples, pl.DataFrame) or not isinstance(value.samples, pl.DataFrame):
+             return False
+
+         if not self.samples.equals(value.samples):
+             return False
+         return True
+
+     def print_stats(self) -> None:
+         table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+         table_base.add_column("Property", style="cyan")
+         table_base.add_column("Value")
+         table_base.add_row("Dataset Name", self.info.dataset_name)
+         table_base.add_row("Version", self.info.version)
+         table_base.add_row("Number of samples", str(len(self.samples)))
+         rprint(table_base)
+         rprint(self.info.tasks)
+
+         splits_sets = {
+             "All": SplitName.valid_splits(),
+             "Train": [SplitName.TRAIN],
+             "Validation": [SplitName.VAL],
+             "Test": [SplitName.TEST],
+         }
+         rows = []
+         for split_name, splits in splits_sets.items():
+             dataset_split = self.create_split_dataset(splits)
+             table = dataset_split.samples
+             row = {}
+             row["Split"] = split_name
+             row["Sample "] = str(len(table))
+             for PrimitiveType in PRIMITIVE_TYPES:
+                 column_name = PrimitiveType.column_name()
+                 objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+                 if objects_df is None:
+                     continue
+                 for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                     count = len(object_group[FieldName.CLASS_NAME])
+                     row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+             rows.append(row)
+
+         rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+         for i_row, row in enumerate(rows):
+             if i_row == 0:
+                 for column_name in row.keys():
+                     rich_table.add_column(column_name, justify="left", style="cyan")
+             rich_table.add_row(*[str(value) for value in row.values()])
+         rprint(rich_table)
+
+
+ def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
+     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
+     check_hafnia_dataset(dataset)
+
+
+ def get_or_create_dataset_path_from_recipe(
+     dataset_recipe: Any,
+     force_redownload: bool = False,
+     path_datasets: Optional[Union[Path, str]] = None,
+ ) -> Path:
+     from hafnia.dataset.dataset_recipe.dataset_recipe import (
+         DatasetRecipe,
+         get_dataset_path_from_recipe,
+     )
+
+     recipe: DatasetRecipe = DatasetRecipe.from_implicit_form(dataset_recipe)
+     path_dataset = get_dataset_path_from_recipe(recipe, path_datasets=path_datasets)
+
+     if force_redownload:
+         shutil.rmtree(path_dataset, ignore_errors=True)
+
+     if HafniaDataset.check_dataset_path(path_dataset, raise_error=False):
+         return path_dataset
+
+     path_dataset.mkdir(parents=True, exist_ok=True)
+     path_recipe_json = path_dataset / FILENAME_RECIPE_JSON
+     path_recipe_json.write_text(recipe.model_dump_json(indent=4))
+
+     dataset: HafniaDataset = recipe.build()
+     dataset.write(path_dataset)
+
+     return path_dataset
+
+
+ def check_hafnia_dataset(dataset: HafniaDataset):
+     user_logger.info("Checking Hafnia dataset...")
+     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
+     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+     is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
+     if True not in is_sample_list:
+         raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
+
+     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+     expected_splits = SplitName.valid_splits()
+     if set(actual_splits) != set(expected_splits):
+         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+     expected_tasks = dataset.info.tasks
+     for task in expected_tasks:
+         primitive = task.primitive.__name__
+         column_name = task.primitive.column_name()
+         primitive_column = dataset.samples[column_name]
+         # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
+         msg_something_wrong = (
+             f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+             f"For '{primitive=}' and '{task.name=}' "
+         )
+         if primitive_column.dtype == pl.Null:
+             raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+         primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+         if primitive_table.is_empty():
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+             )
+
+         actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+         if task.class_names is None:
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+             )
+         defined_classes = set(task.class_names)
+
+         if not actual_classes.issubset(defined_classes):
+             raise ValueError(
+                 msg_something_wrong
+                 + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
+                 f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+             )
+         # Check class_indices
+         mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+             lambda x: task.class_names.index(x), return_dtype=pl.Int64
+         )
+         table_indices = primitive_table[FieldName.CLASS_IDX]
+
+         error_msg = msg_something_wrong + (
+             f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+         )
+         assert mapped_indices.equals(table_indices), error_msg
+
+     distribution = dataset.info.distributions or []
+     distribution_names = [task.name for task in distribution]
+     # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+     for PrimitiveType in PRIMITIVE_TYPES:
+         column_name = PrimitiveType.column_name()
+         if column_name not in dataset.samples.columns:
+             continue
+         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+         if objects_df is None:
+             continue
+         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+             if has_task:
+                 continue
+             if task_name in distribution_names:
+                 continue
+             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+             raise ValueError(
+                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                 f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                 f"classes: {class_names}. "
+             )
+
+     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+         sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation # noqa: F841
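
For orientation, below is a minimal usage sketch of the new HafniaDataset API added above, using only calls visible in this hunk (from_name, splits_by_ratios, select_samples, split_counts, print_stats, write). The dataset name "my-dataset" and the output folder are placeholders, not values taken from the package.

from pathlib import Path

from hafnia.dataset.dataset_names import SplitName
from hafnia.dataset.hafnia_dataset import HafniaDataset

# Download (or reuse a cached copy of) a dataset registered on the Hafnia platform.
dataset = HafniaDataset.from_name("my-dataset")  # placeholder name

# Re-split 80/10/10, keep a shuffled 100-sample subset, and inspect the result.
dataset = dataset.splits_by_ratios({SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1})
subset = dataset.select_samples(n_samples=100, shuffle=True, seed=42)
print(subset.split_counts())
subset.print_stats()

# Persist the subset: images are copied into <folder>/data and annotations are
# written both as JSONL and Parquet next to the dataset info file.
subset.write(Path("./my-dataset-subset"))  # placeholder output folder
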
hafnia/dataset/operations/dataset_stats.py (new file)
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Dict
+
+ from hafnia.dataset.dataset_names import ColumnName
+
+ if TYPE_CHECKING:
+     from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+ def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
+     """
+     Returns a dictionary with the counts of samples in each split of the dataset.
+     """
+     return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
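
As a usage note, split_counts is also bound as a method on the HafniaDataset class in the first hunk (split_counts = dataset_stats.split_counts), so it can be called directly on a loaded dataset. A short sketch, assuming a dataset already on disk; the path and the printed counts are placeholders:

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset

dataset = HafniaDataset.from_path(Path("path/to/dataset"), check_for_images=False)
print(dataset.split_counts())  # e.g. {"train": 8000, "val": 1000, "test": 1000} (illustrative numbers)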