hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +2 -2
- cli/config.py +17 -4
- cli/dataset_cmds.py +60 -0
- cli/runc_cmds.py +1 -1
- hafnia/data/__init__.py +2 -2
- hafnia/data/factory.py +12 -56
- hafnia/dataset/dataset_helpers.py +91 -0
- hafnia/dataset/dataset_names.py +72 -0
- hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
- hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
- hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
- hafnia/dataset/dataset_upload_helper.py +468 -0
- hafnia/dataset/hafnia_dataset.py +624 -0
- hafnia/dataset/operations/dataset_stats.py +15 -0
- hafnia/dataset/operations/dataset_transformations.py +82 -0
- hafnia/dataset/operations/table_transformations.py +183 -0
- hafnia/dataset/primitives/__init__.py +16 -0
- hafnia/dataset/primitives/bbox.py +137 -0
- hafnia/dataset/primitives/bitmask.py +182 -0
- hafnia/dataset/primitives/classification.py +56 -0
- hafnia/dataset/primitives/point.py +25 -0
- hafnia/dataset/primitives/polygon.py +100 -0
- hafnia/dataset/primitives/primitive.py +44 -0
- hafnia/dataset/primitives/segmentation.py +51 -0
- hafnia/dataset/primitives/utils.py +51 -0
- hafnia/experiment/hafnia_logger.py +7 -7
- hafnia/helper_testing.py +108 -0
- hafnia/http.py +5 -3
- hafnia/platform/__init__.py +2 -2
- hafnia/platform/datasets.py +197 -0
- hafnia/platform/download.py +85 -23
- hafnia/torch_helpers.py +180 -95
- hafnia/utils.py +21 -2
- hafnia/visualizations/colors.py +267 -0
- hafnia/visualizations/image_visualizations.py +202 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
- hafnia-0.2.1.dist-info/RECORD +50 -0
- cli/data_cmds.py +0 -53
- hafnia-0.1.27.dist-info/RECORD +0 -27
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,624 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import shutil
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from random import Random
|
|
8
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
9
|
+
|
|
10
|
+
import more_itertools
|
|
11
|
+
import numpy as np
|
|
12
|
+
import polars as pl
|
|
13
|
+
import rich
|
|
14
|
+
from PIL import Image
|
|
15
|
+
from pydantic import BaseModel, field_serializer, field_validator
|
|
16
|
+
from rich import print as rprint
|
|
17
|
+
from rich.table import Table
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from hafnia.dataset import dataset_helpers
|
|
21
|
+
from hafnia.dataset.dataset_names import (
|
|
22
|
+
DATASET_FILENAMES_REQUIRED,
|
|
23
|
+
FILENAME_ANNOTATIONS_JSONL,
|
|
24
|
+
FILENAME_ANNOTATIONS_PARQUET,
|
|
25
|
+
FILENAME_DATASET_INFO,
|
|
26
|
+
FILENAME_RECIPE_JSON,
|
|
27
|
+
ColumnName,
|
|
28
|
+
FieldName,
|
|
29
|
+
SplitName,
|
|
30
|
+
)
|
|
31
|
+
from hafnia.dataset.operations import dataset_stats, dataset_transformations
|
|
32
|
+
from hafnia.dataset.operations.table_transformations import (
|
|
33
|
+
check_image_paths,
|
|
34
|
+
create_primitive_table,
|
|
35
|
+
read_table_from_path,
|
|
36
|
+
)
|
|
37
|
+
from hafnia.dataset.primitives import (
|
|
38
|
+
PRIMITIVE_NAME_TO_TYPE,
|
|
39
|
+
PRIMITIVE_TYPES,
|
|
40
|
+
)
|
|
41
|
+
from hafnia.dataset.primitives.bbox import Bbox
|
|
42
|
+
from hafnia.dataset.primitives.bitmask import Bitmask
|
|
43
|
+
from hafnia.dataset.primitives.classification import Classification
|
|
44
|
+
from hafnia.dataset.primitives.polygon import Polygon
|
|
45
|
+
from hafnia.dataset.primitives.primitive import Primitive
|
|
46
|
+
from hafnia.log import user_logger
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class TaskInfo(BaseModel):
    """
    Description of one annotation task in a dataset.

    A task couples a primitive type (e.g. Bbox, Bitmask, Classification) with an optional list of class names
    and a task name. When 'name' is omitted it is filled in from 'primitive.default_task_name()'.
    """

    primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
    class_names: Optional[List[str]]  # Class names for the task. To get consistent class indices specify class_names.
    name: Optional[str] = (
        None  # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
    )

    def model_post_init(self, __context: Any) -> None:
        # Derive the default task name from the primitive type when none was given.
        if self.name is None:
            self.name = self.primitive.default_task_name()

    # The 'primitive'-field of type 'Type[Primitive]' is not supported by pydantic out-of-the-box, because
    # 'Primitive' is an abstract base class and the actual primitives are its sub-classes (Bbox, Bitmask,
    # Classification, ...). The functions below ('ensure_primitive' and 'serialize_primitive') ensure that the
    # 'primitive' field can correctly validate and serialize these sub-classes.
    @field_validator("primitive", mode="plain")
    @classmethod
    def ensure_primitive(cls, primitive: Any) -> Any:
        """
        Convert the raw 'primitive' value into a Primitive subclass.

        Accepts either a key of 'PRIMITIVE_NAME_TO_TYPE' (a string) or a Primitive subclass.
        Raises ValueError (reported by pydantic as a ValidationError) for anything else.
        """
        if isinstance(primitive, str):
            if primitive not in PRIMITIVE_NAME_TO_TYPE:
                raise ValueError(
                    f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
                )

            return PRIMITIVE_NAME_TO_TYPE[primitive]

        # 'issubclass' raises TypeError for non-class values (e.g. an int or a Bbox instance), and pydantic does
        # not turn TypeError into a ValidationError. Check that the value is a class first.
        if isinstance(primitive, type) and issubclass(primitive, Primitive):
            return primitive

        raise ValueError(f"Primitive must be a string or a Primitive subclass, got {type(primitive)} instead.")

    @field_serializer("primitive")
    @classmethod
    def serialize_primitive(cls, primitive: Type[Primitive]) -> str:
        """Serialize the primitive class by its class name (e.g. 'Bbox')."""
        if not (isinstance(primitive, type) and issubclass(primitive, Primitive)):
            raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
        return primitive.__name__
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class DatasetInfo(BaseModel):
    """Dataset-level metadata (name, version, tasks), stored on disk as a JSON file."""

    dataset_name: str
    version: str
    tasks: List[TaskInfo]
    distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
    meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.

    def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
        """Serialize the model to JSON and write it to 'path' as UTF-8."""
        json_str = self.model_dump_json(indent=indent)
        # Use explicit UTF-8: the JSON may contain non-ASCII text, and 'write_text' otherwise uses the platform
        # locale encoding (e.g. cp1252 on Windows), which can fail or produce files that read back differently.
        path.write_text(json_str, encoding="utf-8")

    @staticmethod
    def from_json_file(path: Path) -> "DatasetInfo":
        """Read and validate a DatasetInfo from a UTF-8 JSON file written by 'write_json'."""
        json_str = path.read_text(encoding="utf-8")
        return DatasetInfo.model_validate_json(json_str)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Sample(BaseModel):
    """
    Pydantic model of a single dataset sample, i.e. one row of 'HafniaDataset.samples'.

    'check_hafnia_dataset' validates every row of a dataset against this model.
    """

    file_name: str
    height: int
    width: int
    split: str  # Split name, e.g., "train", "val", "test"
    # True when the sample is part of the dataset's sample subset; 'HafniaDataset.define_sample_set_by_size'
    # sets it and 'HafniaDataset.create_sample_dataset' filters on it.
    is_sample: bool
    collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
    collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
    remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
    sample_index: Optional[int] = None  # Don't manually set this, it is used for indexing samples in the dataset.
    classifications: Optional[List[Classification]] = None  # Optional classification primitive
    objects: Optional[List[Bbox]] = None  # List of bounding boxes (Bbox primitives), if applicable
    bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
    polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable

    meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.

    def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
        """
        Return all annotations of the sample as one flat list.

        'primitive_types' limits the result to those primitive types (default: all of 'PRIMITIVE_TYPES').
        For each type, the attribute named by 'primitive_type.column_name()' is read. Missing or None
        attributes are skipped.
        """
        primitive_types = primitive_types or PRIMITIVE_TYPES
        annotations_primitives = [
            getattr(self, primitive_type.column_name(), None) for primitive_type in primitive_types
        ]
        annotations = more_itertools.flatten(
            [primitives for primitives in annotations_primitives if primitives is not None]
        )

        return list(annotations)

    def read_image_pillow(self) -> Image.Image:
        """
        Read the image at 'file_name' and return it as a PIL Image.

        Raises FileNotFoundError if the image file does not exist.
        """
        path_image = Path(self.file_name)
        if not path_image.exists():
            raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

        image = Image.open(str(path_image))
        return image

    def read_image(self) -> np.ndarray:
        """Read the image at 'file_name' and return it as a numpy array."""
        image_pil = self.read_image_pillow()
        image = np.array(image_pil)
        return image

    def draw_annotations(self, image: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Draw all annotations of the sample on an image and return the resulting image.

        If 'image' is None, the sample's own image is read from 'file_name'.
        """
        from hafnia.visualizations import image_visualizations

        # Use an explicit None-check. 'image or ...' would evaluate the truth value of a numpy array, which
        # raises ValueError for any array with more than one element.
        if image is None:
            image = self.read_image()
        annotations = self.get_annotations()
        annotations_visualized = image_visualizations.draw_annotations(image=image, primitives=annotations)
        return annotations_visualized
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@dataclass
class HafniaDataset:
    """
    A Hafnia dataset held in memory.

    'info' holds dataset-level metadata (name, version, tasks). 'samples' is a polars table with one row per
    sample. Transformations return a new HafniaDataset (built via 'update_table') instead of modifying this one.
    """

    info: DatasetInfo
    samples: pl.DataFrame

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """Return the sample at row 'item' as a dict of column name -> value."""
        return self.samples.row(index=item, named=True)

    def __len__(self) -> int:
        """Return the number of samples (rows) in the dataset."""
        return len(self.samples)

    def __iter__(self):
        """Yield every sample as a dict of column name -> value."""
        for row in self.samples.iter_rows(named=True):
            yield row

    @staticmethod
    def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
        """
        Load a dataset from a folder in the layout produced by 'write'.

        Image paths stored relative to 'path_folder' are converted to absolute paths. When 'check_for_images' is
        True, 'check_image_paths' is run on the resulting table.

        Raises FileNotFoundError if the folder or one of the required dataset files is missing.
        """
        HafniaDataset.check_dataset_path(path_folder, raise_error=True)

        dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
        table = read_table_from_path(path_folder)

        # Convert from relative paths to absolute paths
        table = table.with_columns(
            pl.concat_str([pl.lit(str(path_folder.absolute()) + os.sep), pl.col("file_name")]).alias("file_name")
        )
        if check_for_images:
            check_image_paths(table)
        return HafniaDataset(samples=table, info=dataset_info)

    @staticmethod
    def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
        """
        Load a dataset by its name. The dataset must be registered in the Hafnia platform.

        'download_files' is forwarded to the download helper. It also controls whether image paths are checked
        after loading.
        """
        from hafnia.platform.datasets import download_or_get_dataset_path

        dataset_path = download_or_get_dataset_path(
            dataset_name=name, force_redownload=force_redownload, download_files=download_files
        )
        return HafniaDataset.from_path(dataset_path, check_for_images=download_files)

    @staticmethod
    def from_samples_list(samples_list: List, info: DatasetInfo) -> "HafniaDataset":
        """
        Build a dataset from a list of 'Sample' objects or a list of sample dicts.

        Any existing 'sample_index' values are discarded and replaced by a fresh 0..N-1 row index.

        Raises ValueError for an empty list and TypeError for unsupported element types.
        """
        if len(samples_list) == 0:
            raise ValueError("Cannot create a dataset from an empty samples list.")

        first_sample = samples_list[0]
        if isinstance(first_sample, Sample):
            json_samples = [sample.model_dump(mode="json") for sample in samples_list]
        elif isinstance(first_sample, dict):
            json_samples = samples_list
        else:
            raise TypeError(f"Unsupported sample type: {type(first_sample)}. Expected Sample or dict.")

        table = pl.from_records(json_samples)
        # Sample dicts are not required to carry a 'sample_index' column; only drop it when present.
        if ColumnName.SAMPLE_INDEX in table.columns:
            table = table.drop(ColumnName.SAMPLE_INDEX)
        table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)  # Add sample index column

        return HafniaDataset(info=info, samples=table)

    @staticmethod
    def from_recipe(dataset_recipe: Any) -> "HafniaDataset":
        """
        Load a dataset from a recipe. The recipe can be a string (name of the dataset), a dictionary, or a
        DatasetRecipe object.
        """
        from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

        recipe_explicit = DatasetRecipe.from_implicit_form(dataset_recipe)

        return recipe_explicit.build()  # Build dataset from the recipe

    @staticmethod
    def from_merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
        """Merge two datasets (see 'merge')."""
        return HafniaDataset.merge(dataset0, dataset1)

    @staticmethod
    def from_recipe_with_cache(
        dataset_recipe: Any,
        force_redownload: bool = False,
        path_datasets: Optional[Union[Path, str]] = None,
    ) -> "HafniaDataset":
        """
        Load a dataset from a recipe and cache it to disk.

        If the dataset is already cached, it is loaded from the cache. With 'force_redownload', any cached copy
        is removed and the dataset is rebuilt.
        """
        path_dataset = get_or_create_dataset_path_from_recipe(
            dataset_recipe, force_redownload=force_redownload, path_datasets=path_datasets
        )
        return HafniaDataset.from_path(path_dataset, check_for_images=False)

    @staticmethod
    def from_merger(
        datasets: List[HafniaDataset],
    ) -> "HafniaDataset":
        """
        Merge multiple Hafnia datasets into one by pairwise 'merge' from left to right.

        Raises ValueError if 'datasets' is empty.
        """
        if len(datasets) == 0:
            raise ValueError("No datasets to merge. Please provide at least one dataset.")

        if len(datasets) == 1:
            return datasets[0]

        merged_dataset = datasets[0]
        for dataset in datasets[1:]:
            merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
        return merged_dataset

    # Dataset transformations
    transform_images = dataset_transformations.transform_images

    def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
        """Return a copy of the dataset with its rows shuffled deterministically by 'seed'."""
        table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
        return dataset.update_table(table)

    def select_samples(
        dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
    ) -> "HafniaDataset":
        """
        Return a dataset with 'n_samples' rows drawn from this one.

        Without replacement, 'n_samples' is capped at the dataset size.
        """
        if not with_replacement:
            n_samples = min(n_samples, len(dataset))
        table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
        return dataset.update_table(table)

    def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
        """
        Divides the dataset into splits based on the provided ratios.

        Example: Defining split ratios and applying the transformation

        >>> dataset = HafniaDataset.from_path(Path("path/to/dataset"))
        >>> split_ratios = {SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1}
        >>> dataset_with_splits = splits_by_ratios(dataset, split_ratios, seed=42)
        Or call it as a method:
        >>> dataset_with_splits = dataset.splits_by_ratios(split_ratios, seed=42)
        """
        n_items = len(dataset)
        split_name_column = dataset_helpers.create_split_name_list_from_ratios(
            split_ratios=split_ratios, n_items=n_items, seed=seed
        )
        table = dataset.samples.with_columns(pl.Series(split_name_column).alias(ColumnName.SPLIT))
        return dataset.update_table(table)

    def split_into_multiple_splits(
        dataset: "HafniaDataset",
        split_name: str,
        split_ratios: Dict[str, float],
        seed: int = 42,
    ) -> "HafniaDataset":
        """
        Divides a dataset split ('split_name') into multiple splits based on the provided split
        ratios ('split_ratios'). This is especially useful for open datasets that provide only
        two splits, or annotations for only two splits. This function allows you to create additional
        splits based on the provided ratios. 'seed' controls the random assignment.

        Example: Defining split ratios and applying the transformation
        >>> dataset = HafniaDataset.from_path(Path("path/to/dataset"))
        >>> split_name = SplitName.TEST
        >>> split_ratios = {SplitName.TEST: 0.8, SplitName.VAL: 0.2}
        >>> dataset_with_splits = split_into_multiple_splits(dataset, split_name, split_ratios)

        Raises ValueError if the 'split_name' split has no samples.
        """
        dataset_split_to_be_divided = dataset.create_split_dataset(split_name=split_name)
        if len(dataset_split_to_be_divided) == 0:
            split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
            raise ValueError(f"No samples in the '{split_name}' split to divide into multiple splits. {split_counts=}")
        dataset_split_to_be_divided = dataset_split_to_be_divided.splits_by_ratios(split_ratios=split_ratios, seed=seed)

        remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
        new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
        dataset_new = dataset.update_table(new_table)
        return dataset_new

    def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
        """
        Mark 'n_samples' randomly chosen rows (deterministic for a given 'seed') with 'is_sample=True' and all
        other rows with False.

        Raises ValueError (from 'random.Random.sample') if 'n_samples' exceeds the dataset size or is negative.
        """
        n_total = len(dataset)
        is_sample_indices = set(Random(seed).sample(range(n_total), n_samples))
        is_sample_column = [idx in is_sample_indices for idx in range(n_total)]

        table = dataset.samples.with_columns(pl.Series(is_sample_column).alias(ColumnName.IS_SAMPLE))
        return dataset.update_table(table)

    def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
        """
        Merges two Hafnia datasets by concatenating their samples. The info of 'dataset0' is kept.
        """
        ## Currently, only a very naive merging is implemented.
        # In the future we need to verify that the class and tasks are compatible.
        # Do they have similar classes and tasks? What to do if they don't?
        # For now, we just concatenate the samples and keep the split names as they are.
        merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
        return dataset0.update_table(merged_samples)

    # Dataset stats
    split_counts = dataset_stats.split_counts

    def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
        """
        Return one dataset per valid split name, keyed by split name.

        Raises ValueError if the table has no split column.
        """
        if ColumnName.SPLIT not in self.samples.columns:
            raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")

        return {split_name: self.create_split_dataset(split_name) for split_name in SplitName.valid_splits()}

    def create_sample_dataset(self) -> "HafniaDataset":
        """
        Return the rows with 'is_sample=True'.

        Raises ValueError if the table has no is_sample column.
        """
        if ColumnName.IS_SAMPLE not in self.samples.columns:
            raise ValueError(f"Dataset must contain an '{ColumnName.IS_SAMPLE}' column.")
        table = self.samples.filter(pl.col(ColumnName.IS_SAMPLE))
        return self.update_table(table)

    def create_split_dataset(self, split_name: Union[str, List[str]]) -> "HafniaDataset":
        """
        Return the rows whose split is 'split_name' (a single name or a list/tuple/set of names).

        Raises TypeError for other argument types and ValueError for names not in 'SplitName.valid_splits()'.
        """
        if isinstance(split_name, str):
            split_names = [split_name]
        elif isinstance(split_name, (list, tuple, set)):
            split_names = list(split_name)
        else:
            # Previously this fell through to an UnboundLocalError.
            raise TypeError(f"'split_name' must be a string or a list of strings, got {type(split_name)} instead.")

        valid_splits = SplitName.valid_splits()
        for name in split_names:
            if name not in valid_splits:
                raise ValueError(f"Invalid split name: {name}. Valid splits are: {valid_splits}")

        filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
        return self.update_table(filtered_dataset)

    def get_task_by_name(self, task_name: str) -> TaskInfo:
        """Return the task in 'info.tasks' named 'task_name'. Raises ValueError if there is none."""
        for task in self.info.tasks:
            if task.name == task_name:
                return task
        raise ValueError(f"Task with name {task_name} not found in dataset info.")

    def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
        """Return a new dataset with 'table' as samples and a copy of this dataset's info."""
        return HafniaDataset(info=self.info.model_copy(), samples=table)

    @staticmethod
    def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
        """
        Checks if the dataset path exists and contains the required files.
        Returns True if the dataset is valid. Otherwise raises FileNotFoundError (raise_error=True) or returns False.
        """
        if not path_dataset.exists():
            if raise_error:
                raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
            return False

        required_files = [
            FILENAME_DATASET_INFO,
            FILENAME_ANNOTATIONS_JSONL,
            FILENAME_ANNOTATIONS_PARQUET,
        ]
        for filename in required_files:
            if not (path_dataset / filename).exists():
                if raise_error:
                    raise FileNotFoundError(f"Required file {filename} not found in {path_dataset}.")
                return False

        return True

    def write(self, path_folder: Path, name_by_hash: bool = True, add_version: bool = False) -> None:
        """
        Write the dataset to 'path_folder'.

        - Images are copied into 'path_folder/data'. With 'name_by_hash' they are named by
          'dataset_helpers.filename_as_hash_from_path'; otherwise they keep their original file name. Existing
          destination files are not overwritten.
        - The sample table is written as JSONL and Parquet, with 'file_name' relative to 'path_folder', and the
          info is written as JSON.
        - With 'add_version', the files in 'DATASET_FILENAMES_REQUIRED' are also copied to
          'path_folder/versions/<info.version>'.

        Raises FileNotFoundError if a source image is missing or was not copied.
        """
        user_logger.info(f"Writing dataset to {path_folder}...")
        if not path_folder.exists():
            path_folder.mkdir(parents=True)
        path_folder_images = path_folder / "data"
        path_folder_images.mkdir(parents=True, exist_ok=True)

        new_relative_paths = []
        for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
            org_path = Path(org_path)
            if not org_path.exists():
                raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
            if name_by_hash:
                filename = dataset_helpers.filename_as_hash_from_path(org_path)
            else:
                # NOTE(review): with name_by_hash=False, different source files sharing a file name map to the
                # same destination, and only the first one is copied.
                filename = org_path.name
            new_path = path_folder_images / filename
            if not new_path.exists():
                shutil.copy2(org_path, new_path)

            if not new_path.exists():
                raise FileNotFoundError(f"File {new_path} does not exist in the dataset.")
            new_relative_paths.append(str(new_path.relative_to(path_folder)))

        table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
        table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
        table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
        self.info.write_json(path_folder / FILENAME_DATASET_INFO)

        if add_version:
            path_version = path_folder / "versions" / f"{self.info.version}"
            path_version.mkdir(parents=True, exist_ok=True)
            for filename in DATASET_FILENAMES_REQUIRED:
                shutil.copy2(path_folder / filename, path_version / filename)

    def __eq__(self, value) -> bool:
        """Return True if 'value' is a HafniaDataset with equal info and an equal samples table."""
        if not isinstance(value, HafniaDataset):
            return False

        if self.info != value.info:
            return False

        if not isinstance(self.samples, pl.DataFrame) or not isinstance(value.samples, pl.DataFrame):
            return False

        if not self.samples.equals(value.samples):
            return False
        return True

    def print_stats(self) -> None:
        """Print the dataset name/version/size, the task definitions, and per-split annotation counts per task."""
        table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
        table_base.add_column("Property", style="cyan")
        table_base.add_column("Value")
        table_base.add_row("Dataset Name", self.info.dataset_name)
        table_base.add_row("Version", self.info.version)
        table_base.add_row("Number of samples", str(len(self.samples)))
        rprint(table_base)
        rprint(self.info.tasks)

        splits_sets = {
            "All": SplitName.valid_splits(),
            "Train": [SplitName.TRAIN],
            "Validation": [SplitName.VAL],
            "Test": [SplitName.TEST],
        }
        rows = []
        for split_name, splits in splits_sets.items():
            dataset_split = self.create_split_dataset(splits)
            table = dataset_split.samples
            row = {}
            row["Split"] = split_name
            row["Sample "] = str(len(table))
            for PrimitiveType in PRIMITIVE_TYPES:
                objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
                if objects_df is None:
                    continue
                for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
                    count = len(object_group[FieldName.CLASS_NAME])
                    row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
            rows.append(row)

        # Rows can have different keys (a split may have no objects for some task), and 'group_by' does not
        # guarantee group order. So the columns are the ordered union of all row keys, and each cell is looked up
        # by column name instead of by position. A missing key means zero objects.
        column_names = list(dict.fromkeys(key for row in rows for key in row))
        rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
        for column_name in column_names:
            rich_table.add_column(column_name, justify="left", style="cyan")
        for row in rows:
            rich_table.add_row(*[str(row.get(column_name, "0")) for column_name in column_names])
        rprint(rich_table)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
    """Load the dataset stored in 'path_dataset' (verifying that its image files exist) and validate it."""
    loaded_dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
    check_hafnia_dataset(loaded_dataset)
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def get_or_create_dataset_path_from_recipe(
    dataset_recipe: Any,
    force_redownload: bool = False,
    path_datasets: Optional[Union[Path, str]] = None,
) -> Path:
    """
    Return the on-disk cache folder for a dataset recipe, building and writing the dataset there if needed.

    The recipe is converted with 'DatasetRecipe.from_implicit_form'. Its cache folder comes from
    'get_dataset_path_from_recipe'. With 'force_redownload', the folder is deleted first. If the folder does not
    hold a complete dataset, the recipe JSON is written and the dataset is built and written into it.
    """
    from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe, get_dataset_path_from_recipe

    recipe_explicit: DatasetRecipe = DatasetRecipe.from_implicit_form(dataset_recipe)
    path_cached = get_dataset_path_from_recipe(recipe_explicit, path_datasets=path_datasets)

    if force_redownload:
        shutil.rmtree(path_cached, ignore_errors=True)

    is_cached = HafniaDataset.check_dataset_path(path_cached, raise_error=False)
    if not is_cached:
        path_cached.mkdir(parents=True, exist_ok=True)
        (path_cached / FILENAME_RECIPE_JSON).write_text(recipe_explicit.model_dump_json(indent=4))

        built_dataset: HafniaDataset = recipe_explicit.build()
        built_dataset.write(path_cached)

    return path_cached
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def check_hafnia_dataset(dataset: HafniaDataset):
    """
    Validate the internal consistency of a HafniaDataset.

    Checks, in order:
    - 'info.version' and 'info.dataset_name' are non-empty strings (assert).
    - At least one row has 'is_sample=True'.
    - The split names in the table equal 'SplitName.valid_splits()'.
    - For every task in 'info.tasks':
      - its primitive column exists and is not Null-typed;
      - the column contains objects with that task name;
      - 'task.class_names' is defined, and the object class names are a subset of it;
      - the class indices equal each class name's position in 'task.class_names' (assert).
    - Every task name found in the primitive columns is declared in 'info.tasks' (or 'info.distributions').
    - Every row validates against the 'Sample' model.

    Raises ValueError (or AssertionError for the assert-based checks) on the first failure.
    """
    user_logger.info("Checking Hafnia dataset...")
    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

    is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
    if True not in is_sample_list:
        raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")

    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
    expected_splits = SplitName.valid_splits()
    if set(actual_splits) != set(expected_splits):
        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")

    expected_tasks = dataset.info.tasks
    for task in expected_tasks:
        primitive = task.primitive.__name__
        column_name = task.primitive.column_name()
        msg_something_wrong = (
            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
            f"For '{primitive=}' and '{task.name=}' "
        )
        # Report a missing column clearly instead of letting the column lookup fail with a polars error.
        if column_name not in dataset.samples.columns:
            raise ValueError(msg_something_wrong + f"the column '{column_name}' is missing. Please check the dataset.")
        primitive_column = dataset.samples[column_name]
        if primitive_column.dtype == pl.Null:
            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")

        # One row per object (list column exploded, struct fields unnested), restricted to this task.
        primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
        if primitive_table.is_empty():
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
            )

        actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
        if task.class_names is None:
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
            )
        defined_classes = set(task.class_names)

        if not actual_classes.issubset(defined_classes):
            raise ValueError(
                msg_something_wrong
                + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
                f"be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
            )
        # Check class_indices
        mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
            lambda x: task.class_names.index(x), return_dtype=pl.Int64
        )
        table_indices = primitive_table[FieldName.CLASS_IDX]

        error_msg = msg_something_wrong + (
            f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
        )
        assert mapped_indices.equals(table_indices), error_msg

    distribution = dataset.info.distributions or []
    distribution_names = [task.name for task in distribution]
    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
    for PrimitiveType in PRIMITIVE_TYPES:
        column_name = PrimitiveType.column_name()
        if column_name not in dataset.samples.columns:
            continue
        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
        if objects_df is None:
            continue
        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
            has_task = any(t.name == task_name and t.primitive == PrimitiveType for t in expected_tasks)
            if has_task:
                continue
            if task_name in distribution_names:
                continue
            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
            raise ValueError(
                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
                f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. Missing task has the following "
                f"classes: {class_names}. "
            )

    for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
        sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation # noqa: F841
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Dict
|
|
4
|
+
|
|
5
|
+
from hafnia.dataset.dataset_names import ColumnName
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from hafnia.dataset.hafnia_dataset import HafniaDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
    """
    Count the samples in each split.

    Returns a dict that maps every split name present in the dataset's split column to its number of samples.
    """
    counts_table = dataset.samples[ColumnName.SPLIT].value_counts()
    return {split_name: n_samples for split_name, n_samples in counts_table.iter_rows()}
|