hafnia-0.2.4-py3-none-any.whl → hafnia-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +16 -3
- cli/config.py +45 -4
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/keychain.py +88 -0
- cli/profile_cmds.py +10 -6
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +9 -14
- hafnia/dataset/dataset_names.py +10 -5
- hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
- hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +265 -56
- hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
- hafnia/dataset/hafnia_dataset.py +577 -213
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +259 -3
- hafnia/dataset/operations/dataset_transformations.py +332 -7
- hafnia/dataset/operations/table_transformations.py +43 -5
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/bbox.py +25 -12
- hafnia/dataset/primitives/bitmask.py +26 -14
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +16 -9
- hafnia/dataset/primitives/segmentation.py +10 -7
- hafnia/experiment/hafnia_logger.py +111 -8
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +104 -0
- hafnia/platform/datasets.py +47 -9
- hafnia/platform/download.py +25 -19
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +81 -13
- hafnia/visualizations/image_visualizations.py +4 -4
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
- hafnia-0.4.0.dist-info/RECORD +56 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
CHANGED
@@ -1,21 +1,24 @@
 from __future__ import annotations
 
+import collections
+import copy
+import json
 import shutil
 from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
 from random import Random
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import more_itertools
 import numpy as np
 import polars as pl
-import
+from packaging.version import Version
 from PIL import Image
-from pydantic import BaseModel, field_serializer, field_validator
-from rich import
-from rich.table import Table
-from tqdm import tqdm
+from pydantic import BaseModel, Field, field_serializer, field_validator
+from rich.progress import track
 
+import hafnia
 from hafnia.dataset import dataset_helpers
 from hafnia.dataset.dataset_names import (
     DATASET_FILENAMES_REQUIRED,
@@ -23,20 +26,20 @@ from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_PARQUET,
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
+    TAG_IS_SAMPLE,
     ColumnName,
-    FieldName,
     SplitName,
 )
-from hafnia.dataset.operations import
+from hafnia.dataset.operations import (
+    dataset_stats,
+    dataset_transformations,
+    table_transformations,
+)
 from hafnia.dataset.operations.table_transformations import (
     check_image_paths,
-
-    read_table_from_path,
-)
-from hafnia.dataset.primitives import (
-    PRIMITIVE_NAME_TO_TYPE,
-    PRIMITIVE_TYPES,
+    read_samples_from_path,
 )
+from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
 from hafnia.dataset.primitives.bitmask import Bitmask
 from hafnia.dataset.primitives.classification import Classification
@@ -46,10 +49,16 @@ from hafnia.log import user_logger
 
 
 class TaskInfo(BaseModel):
-    primitive: Type[Primitive]
-
-
-
+    primitive: Type[Primitive] = Field(
+        description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+    )
+    class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+    name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Optional name for the task. 'None' will use default name of the provided primitive. "
+            "e.g. Bbox ->'bboxes', Bitmask -> 'bitmasks' etc."
+        ),
     )
 
     def model_post_init(self, __context: Any) -> None:
@@ -64,12 +73,7 @@ class TaskInfo(BaseModel):
     @classmethod
     def ensure_primitive(cls, primitive: Any) -> Any:
         if isinstance(primitive, str):
-
-            raise ValueError(
-                f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
-            )
-
-            return PRIMITIVE_NAME_TO_TYPE[primitive]
+            return get_primitive_type_from_string(primitive)
 
         if issubclass(primitive, Primitive):
             return primitive
@@ -83,40 +87,273 @@ class TaskInfo(BaseModel):
         raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
         return primitive.__name__
 
+    @field_validator("class_names", mode="after")
+    @classmethod
+    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
+        """Validate that class names are unique"""
+        if class_names is None:
+            return None
+        duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
+        if duplicate_class_names:
+            raise ValueError(
+                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
+            )
+        return class_names
+
+    # To get unique hash value for TaskInfo objects
+    def __hash__(self) -> int:
+        class_names = self.class_names or []
+        return hash((self.name, self.primitive.__name__, tuple(class_names)))
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, TaskInfo):
+            return False
+        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
+
 
 class DatasetInfo(BaseModel):
-    dataset_name: str
-    version: str
-    tasks:
-    distributions: Optional[List[TaskInfo]] = None
-
+    dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+    version: Optional[str] = Field(default=None, description="Version of the dataset")
+    tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+    distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
+    reference_bibtex: Optional[str] = Field(
+        default=None,
+        description="Optional, BibTeX reference to dataset publication",
+    )
+    reference_paper_url: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to dataset publication",
+    )
+    reference_dataset_page: Optional[str] = Field(
+        default=None,
+        description="Optional, URL to the dataset page",
+    )
+    meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+    format_version: str = Field(
+        default=hafnia.__dataset_format_version__,
+        description="Version of the Hafnia dataset format. You should not set this manually.",
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.now,
+        description="Timestamp of the last update to the dataset info. You should not set this manually.",
+    )
+
+    @field_validator("tasks", mode="after")
+    @classmethod
+    def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+        if tasks is None:
+            return []
+        task_name_counts = collections.Counter(task.name for task in tasks)
+        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
+        if duplicate_task_names:
+            raise ValueError(
+                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
+            )
+        return tasks
+
+    @field_validator("format_version")
+    @classmethod
+    def _validate_format_version(cls, format_version: str) -> str:
+        try:
+            Version(format_version)
+        except Exception as e:
+            raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+        if Version(format_version) > Version(hafnia.__dataset_format_version__):
+            user_logger.warning(
+                f"The loaded dataset format version '{format_version}' is newer than the format version "
+                f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                f"updating Hafnia package."
+            )
+        return format_version
+
+    @field_validator("version")
+    @classmethod
+    def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+        if dataset_version is None:
+            return None
+
+        try:
+            Version(dataset_version)
+        except Exception as e:
+            raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+        return dataset_version
+
+    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
+        return self._validate_check_for_duplicate_tasks(self.tasks)
 
     def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
         json_str = self.model_dump_json(indent=indent)
         path.write_text(json_str)
 
     @staticmethod
-    def from_json_file(path: Path) ->
+    def from_json_file(path: Path) -> DatasetInfo:
         json_str = path.read_text()
-
+
+        # TODO: Deprecated support for old dataset info without format_version
+        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
+        # when all datasets include a 'format_version' field
+        json_dict = json.loads(json_str)
+        if "format_version" not in json_dict:
+            json_dict["format_version"] = "0.0.0"
+
+        if "updated_at" not in json_dict:
+            json_dict["updated_at"] = datetime.min.isoformat()
+        dataset_info = DatasetInfo.model_validate(json_dict)
+
+        return dataset_info
+
+    @staticmethod
+    def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
+        """
+        Merges two DatasetInfo objects into one and validates if they are compatible.
+        """
+        for task_ds0 in info0.tasks:
+            for task_ds1 in info1.tasks:
+                same_name = task_ds0.name == task_ds1.name
+                same_primitive = task_ds0.primitive == task_ds1.primitive
+                same_name_different_primitive = same_name and not same_primitive
+                if same_name_different_primitive:
+                    raise ValueError(
+                        f"Cannot merge datasets with different primitives for the same task name: "
+                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
+                        f"'{task_ds1.primitive}' in dataset1."
+                    )
+
+                is_same_name_and_primitive = same_name and same_primitive
+                if is_same_name_and_primitive:
+                    task_ds0_class_names = task_ds0.class_names or []
+                    task_ds1_class_names = task_ds1.class_names or []
+                    if task_ds0_class_names != task_ds1_class_names:
+                        raise ValueError(
+                            f"Cannot merge datasets with different class names for the same task name and primitive: "
+                            f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
+                            f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
+                        )
+
+        if info1.format_version != info0.format_version:
+            user_logger.warning(
+                "Dataset format version of the two datasets do not match. "
+                f"'{info1.format_version}' vs '{info0.format_version}'."
+            )
+        dataset_format_version = info0.format_version
+        if hafnia.__dataset_format_version__ != dataset_format_version:
+            user_logger.warning(
+                f"Dataset format version '{dataset_format_version}' does not match the current "
+                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
+            )
+        unique_tasks = set(info0.tasks + info1.tasks)
+        distributions = set((info0.distributions or []) + (info1.distributions or []))
+        meta = (info0.meta or {}).copy()
+        meta.update(info1.meta or {})
+        return DatasetInfo(
+            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
+            version=None,
+            tasks=list(unique_tasks),
+            distributions=list(distributions),
+            meta=meta,
+            format_version=dataset_format_version,
+        )
+
+    def get_task_by_name(self, task_name: str) -> TaskInfo:
+        """
+        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
+        """
+        tasks_with_name = [task for task in self.tasks if task.name == task_name]
+        if not tasks_with_name:
+            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
+        if len(tasks_with_name) > 1:
+            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
+        return tasks_with_name[0]
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+        if isinstance(primitive, str):
+            primitive = get_primitive_type_from_string(primitive)
+
+        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
+        if not tasks_with_primitive:
+            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
+        if len(tasks_with_primitive) > 1:
+            raise ValueError(
+                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
+            )
+        return tasks_with_primitive[0]
+
+    def get_task_by_task_name_and_primitive(
+        self,
+        task_name: Optional[str],
+        primitive: Optional[Union[Type[Primitive], str]],
+    ) -> TaskInfo:
+        """
+        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
+        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
+        """
+        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
+            tasks=self.tasks,
+            primitive=primitive,
+            task_name=task_name,
+        )
+        return task
+
+    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
+        dataset_info = self.model_copy(deep=True)
+        has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
+        if not has_task:
+            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
+
+        new_tasks = []
+        for task in dataset_info.tasks:
+            if task.name == old_task.name and task.primitive == old_task.primitive:
+                if new_task is None:
+                    continue  # Remove the task
+                new_tasks.append(new_task)
+            else:
+                new_tasks.append(task)
+
+        dataset_info.tasks = new_tasks
+        return dataset_info
 
 
 class Sample(BaseModel):
-
-    height: int
-    width: int
-    split: str
-
-
-
-
-
-
-
-
-
-
-
+    file_path: str = Field(description="Path to the image file")
+    height: int = Field(description="Height of the image")
+    width: int = Field(description="Width of the image")
+    split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+    tags: List[str] = Field(
+        default_factory=list,
+        description="Tags for a given sample. Used for creating subsets of the dataset.",
+    )
+    collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+    collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+    remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+    sample_index: Optional[int] = Field(
+        default=None,
+        description="Don't manually set this, it is used for indexing samples in the dataset.",
+    )
+    classifications: Optional[List[Classification]] = Field(
+        default=None, description="Optional list of classifications"
+    )
+    objects: Optional[List[Bbox]] = Field(default=None, description="Optional list of objects (bounding boxes)")
+    bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+    polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+    attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+    dataset_name: Optional[str] = Field(
+        default=None,
+        description=(
+            "Don't manually set this, it will be automatically defined during initialization. "
+            "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+        ),
+    )
+    meta: Optional[Dict] = Field(
+        default=None,
+        description="Additional metadata, e.g., camera settings, GPS data, etc.",
+    )
 
     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
         """
@@ -137,7 +374,7 @@ class Sample(BaseModel):
         Reads the image from the file path and returns it as a PIL Image.
         Raises FileNotFoundError if the image file does not exist.
         """
-        path_image = Path(self.
+        path_image = Path(self.file_path)
         if not path_image.exists():
             raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")
 
@@ -158,11 +395,93 @@ class Sample(BaseModel):
         return annotations_visualized
 
 
+class License(BaseModel):
+    """License information"""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
+        max_length=100,
+    )
+    name_short: Optional[str] = Field(
+        default=None,
+        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
+        max_length=100,
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description=(
+            "License description e.g. 'You must give appropriate credit, provide a "
+            "link to the license, and indicate if changes were made.'"
+        ),
+    )
+
+    valid_date: Optional[datetime] = Field(
+        default=None,
+        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
+    )
+
+    permissions: Optional[List[str]] = Field(
+        default=None,
+        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
+    )
+    liability: Optional[str] = Field(
+        default=None,
+        description="License liability. Optional and not always applicable.",
+    )
+    location: Optional[str] = Field(
+        default=None,
+        description=(
+            "License Location. E.g. Iowa state. This is essential to understand the industry and "
+            "privacy location specific rules that applies to the data. Optional and not always applicable."
+        ),
+    )
+    notes: Optional[str] = Field(
+        default=None,
+        description="Additional license notes. Optional and not always applicable.",
+    )
+
+
+class Attribution(BaseModel):
+    """Attribution information for the image: Giving source and credit to the original creator"""
+
+    title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
+    creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
+    creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
+    date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
+    copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
+    licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
+    disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
+    changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
+    source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+
+
 @dataclass
 class HafniaDataset:
     info: DatasetInfo
     samples: pl.DataFrame
 
+    # Function mapping: Dataset stats
+    split_counts = dataset_stats.split_counts
+    class_counts_for_task = dataset_stats.class_counts_for_task
+    class_counts_all = dataset_stats.class_counts_all
+
+    # Function mapping: Print stats
+    print_stats = dataset_stats.print_stats
+    print_sample_and_task_counts = dataset_stats.print_sample_and_task_counts
+    print_class_distribution = dataset_stats.print_class_distribution
+
+    # Function mapping: Dataset checks
+    check_dataset = dataset_stats.check_dataset
+    check_dataset_tasks = dataset_stats.check_dataset_tasks
+
+    # Function mapping: Dataset transformations
+    transform_images = dataset_transformations.transform_images
+
     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
 
@@ -173,30 +492,36 @@ class HafniaDataset:
         for row in self.samples.iter_rows(named=True):
             yield row
 
+    def __post_init__(self):
+        self.samples, self.info = _dataset_corrections(self.samples, self.info)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+        path_folder = Path(path_folder)
        HafniaDataset.check_dataset_path(path_folder, raise_error=True)
 
         dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-
+        samples = read_samples_from_path(path_folder)
+        samples, dataset_info = _dataset_corrections(samples, dataset_info)
 
         # Convert from relative paths to absolute paths
         dataset_root = path_folder.absolute().as_posix() + "/"
-
+        samples = samples.with_columns((dataset_root + pl.col(ColumnName.FILE_PATH)).alias(ColumnName.FILE_PATH))
         if check_for_images:
-            check_image_paths(
-        return HafniaDataset(samples=
+            check_image_paths(samples)
+        return HafniaDataset(samples=samples, info=dataset_info)
 
     @staticmethod
     def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
         """
         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
         """
-        from hafnia.dataset.hafnia_dataset import HafniaDataset
         from hafnia.platform.datasets import download_or_get_dataset_path
 
         dataset_path = download_or_get_dataset_path(
-            dataset_name=name,
+            dataset_name=name,
+            force_redownload=force_redownload,
+            download_files=download_files,
         )
         return HafniaDataset.from_path(dataset_path, check_for_images=download_files)
 
@@ -210,9 +535,16 @@ class HafniaDataset:
         else:
             raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")
 
-        table = pl.from_records(json_samples)
-        table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)
+        table = pl.from_records(json_samples)
+        table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
 
+        # Add 'dataset_name' to samples
+        table = table.with_columns(
+            pl.when(pl.col(ColumnName.DATASET_NAME).is_null())
+            .then(pl.lit(info.dataset_name))
+            .otherwise(pl.col(ColumnName.DATASET_NAME))
+            .alias(ColumnName.DATASET_NAME)
+        )
         return HafniaDataset(info=info, samples=table)
 
     @staticmethod
@@ -241,7 +573,11 @@ class HafniaDataset:
         If the dataset is already cached, it will be loaded from the cache.
         """
 
-        path_dataset = get_or_create_dataset_path_from_recipe(
+        path_dataset = get_or_create_dataset_path_from_recipe(
+            dataset_recipe,
+            path_datasets=path_datasets,
+            force_redownload=force_redownload,
+        )
         return HafniaDataset.from_path(path_dataset, check_for_images=False)
 
     @staticmethod
@@ -263,20 +599,46 @@ class HafniaDataset:
         merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
         return merged_dataset
 
-
-
+    @staticmethod
+    def from_name_public_dataset(
+        name: str,
+        force_redownload: bool = False,
+        n_samples: Optional[int] = None,
+    ) -> HafniaDataset:
+        from hafnia.dataset.format_conversions.torchvision_datasets import (
+            torchvision_to_hafnia_converters,
+        )
+
+        name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+        if name not in name_to_torchvision_function:
+            raise ValueError(
+                f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+            )
+        vision_dataset = name_to_torchvision_function[name]
+        return vision_dataset(
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
 
     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
-        return dataset.
+        return dataset.update_samples(table)
 
     def select_samples(
-        dataset: "HafniaDataset",
+        dataset: "HafniaDataset",
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> "HafniaDataset":
+        """
+        Create a new dataset with a subset of samples.
+        """
         if not with_replacement:
             n_samples = min(n_samples, len(dataset))
         table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
-        return dataset.
+        return dataset.update_samples(table)
 
     def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
         """
@@ -295,7 +657,7 @@ class HafniaDataset:
             split_ratios=split_ratios, n_items=n_items, seed=seed
         )
         table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
-        return dataset.
+        return dataset.update_samples(table)
 
     def split_into_multiple_splits(
         dataset: "HafniaDataset",
@@ -323,33 +685,124 @@ class HafniaDataset:
 
         remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
        new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
-        dataset_new = dataset.
+        dataset_new = dataset.update_samples(new_table)
         return dataset_new
 
     def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+        """
+        Defines a sample set randomly by selecting 'n_samples' samples from the dataset.
+        """
+        samples = dataset.samples
+
+        # Remove any pre-existing "sample"-tags
+        samples = samples.with_columns(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+        )
+
+        # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
-
-
-
+        samples = samples.with_columns(
+            pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
+            .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(ColumnName.TAGS))
+        )
+        return dataset.update_samples(samples)
+
+    def class_mapper(
+        dataset: "HafniaDataset",
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
+        method: str = "strict",
+        primitive: Optional[Type[Primitive]] = None,
+        task_name: Optional[str] = None,
+    ) -> "HafniaDataset":
+        """
+        Map class names to new class names using a strict mapping.
+        A strict mapping means that all class names in the dataset must be mapped to a new class name.
+        If a class name is not mapped, an error is raised.
+
+        The class indices are determined by the order of appearance of the new class names in the mapping.
+        Duplicates in the new class names are removed, preserving the order of first appearance.
+
+        E.g.
+
+        mnist = HafniaDataset.from_name("mnist")
+        strict_class_mapping = {
+            "1 - one": "odd",  # 'odd' appears first and becomes class index 0
+            "3 - three": "odd",
+            "5 - five": "odd",
+            "7 - seven": "odd",
+            "9 - nine": "odd",
+            "0 - zero": "even",  # 'even' appears second and becomes class index 1
+            "2 - two": "even",
+            "4 - four": "even",
+            "6 - six": "even",
+            "8 - eight": "even",
+        }
+
+        dataset_new = class_mapper(dataset=mnist, class_mapping=strict_class_mapping)
+
+        """
+        return dataset_transformations.class_mapper(
+            dataset=dataset,
+            class_mapping=class_mapping,
+            method=method,
+            primitive=primitive,
+            task_name=task_name,
+        )
 
-
-
+    def rename_task(
+        dataset: "HafniaDataset",
+        old_task_name: str,
+        new_task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Rename a task in the dataset.
+        """
+        return dataset_transformations.rename_task(
+            dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
+        )
+
+    def select_samples_by_class_name(
+        dataset: HafniaDataset,
+        name: Union[List[str], str],
+        task_name: Optional[str] = None,
+        primitive: Optional[Type[Primitive]] = None,
+    ) -> HafniaDataset:
+        """
+        Select samples that contain at least one annotation with the specified class name(s).
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        return dataset_transformations.select_samples_by_class_name(
+            dataset=dataset, name=name, task_name=task_name, primitive=primitive
+        )
 
     def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
         """
         Merges two Hafnia datasets by concatenating their samples and updating the split names.
         """
-        ## Currently, only a very naive merging is implemented.
-        # In the future we need to verify that the class and tasks are compatible.
-        # Do they have similar classes and tasks? What to do if they don't?
-        # For now, we just concatenate the samples and keep the split names as they are.
-        merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
-        return dataset0.update_table(merged_samples)
 
-
-
+        # Merges dataset info and checks for compatibility
+        merged_info = DatasetInfo.merge(dataset0.info, dataset1.info)
+
+        # Merges samples tables (removes incompatible columns)
+        merged_samples = table_transformations.merge_samples(samples0=dataset0.samples, samples1=dataset1.samples)
+
+        # Check if primitives have been removed during the merge_samples
+        for task in copy.deepcopy(merged_info.tasks):
+            if task.primitive.column_name() not in merged_samples.columns:
+                user_logger.warning(
+                    f"Task '{task.name}' with primitive '{task.primitive.__name__}' has been removed during the merge. "
+                    "This happens if the two datasets do not have the same primitives."
+                )
+                merged_info = merged_info.replace_task(old_task=task, new_task=None)
+
+        return HafniaDataset(info=merged_info, samples=merged_samples)
 
     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+        """
+        Splits the dataset into multiple datasets based on the 'split' column.
+        Returns a dictionary with split names as keys and HafniaDataset objects as values.
+        """
         if ColumnName.SPLIT not in self.samples.columns:
             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
 
@@ -360,10 +813,22 @@ class HafniaDataset:
         return splits
 
     def create_sample_dataset(self) -> "HafniaDataset":
-
-
-
-
+        # Backwards compatibility. Remove in future versions when dataset have been updated
+        if "is_sample" in self.samples.columns:
+            user_logger.warning(
+                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
+                "Please use the 'tags' column with the tag 'sample' instead."
+            )
+            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
+            return self.update_samples(table)
+
+        if ColumnName.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+
+        table = self.samples.filter(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+        )
+        return self.update_samples(table)
 
     def create_split_dataset(self, split_name: Union[str | List[str]]) -> "HafniaDataset":
         if isinstance(split_name, str):
@@ -376,16 +841,12 @@ class HafniaDataset:
             raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")
 
         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
-        return self.
-
-    def get_task_by_name(self, task_name: str) -> TaskInfo:
-        for task in self.info.tasks:
-            if task.name == task_name:
-                return task
-        raise ValueError(f"Task with name {task_name} not found in dataset info.")
+        return self.update_samples(filtered_dataset)
 
-    def
-
+    def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
+        dataset = HafniaDataset(info=self.info.model_copy(deep=True), samples=table)
+        dataset.check_dataset_tasks()
+        return dataset
 
     @staticmethod
     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
@@ -411,19 +872,27 @@ class HafniaDataset:
 
         return True
 
-    def
+    def copy(self) -> "HafniaDataset":
+        return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
+
+    def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
 
         new_relative_paths = []
-
+        org_paths = self.samples[ColumnName.FILE_PATH].to_list()
+        for org_path in track(org_paths, description="- Copy images"):
            new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                 path_source=Path(org_path),
                 path_dataset_root=path_folder,
             )
             new_relative_paths.append(str(new_path.relative_to(path_folder)))
-        table = self.samples.with_columns(pl.Series(new_relative_paths).alias(
+        table = self.samples.with_columns(pl.Series(new_relative_paths).alias(ColumnName.FILE_PATH))
+
+        if drop_null_cols:  # Drops all unused/Null columns
+            table = table.drop(pl.selectors.by_dtype(pl.Null))
+
         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
@@ -448,51 +917,10 @@ class HafniaDataset:
             return False
         return True
 
-    def print_stats(self) -> None:
-        table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        table_base.add_column("Property", style="cyan")
-        table_base.add_column("Value")
-        table_base.add_row("Dataset Name", self.info.dataset_name)
-        table_base.add_row("Version", self.info.version)
-        table_base.add_row("Number of samples", str(len(self.samples)))
-        rprint(table_base)
-        rprint(self.info.tasks)
-
-        splits_sets = {
-            "All": SplitName.valid_splits(),
-            "Train": [SplitName.TRAIN],
-            "Validation": [SplitName.VAL],
-            "Test": [SplitName.TEST],
-        }
-        rows = []
-        for split_name, splits in splits_sets.items():
-            dataset_split = self.create_split_dataset(splits)
-            table = dataset_split.samples
-            row = {}
-            row["Split"] = split_name
-            row["Sample "] = str(len(table))
-            for PrimitiveType in PRIMITIVE_TYPES:
-                column_name = PrimitiveType.column_name()
-                objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-                if objects_df is None:
-                    continue
-                for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                    count = len(object_group[FieldName.CLASS_NAME])
-                    row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-            rows.append(row)
-
-        rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        for i_row, row in enumerate(rows):
-            if i_row == 0:
-                for column_name in row.keys():
-                    rich_table.add_column(column_name, justify="left", style="cyan")
-            rich_table.add_row(*[str(value) for value in row.values()])
-        rprint(rich_table)
-
 
 def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
-
+    dataset.check_dataset()
 
 
 def get_or_create_dataset_path_from_recipe(
@@ -524,87 +952,23 @@ def get_or_create_dataset_path_from_recipe(
     return path_dataset
 
 
-def
-
-    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
-    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
-
-    is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
-    if True not in is_sample_list:
-        raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
-
-    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-    expected_splits = SplitName.valid_splits()
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
-
-    expected_tasks = dataset.info.tasks
-    for task in expected_tasks:
-        primitive = task.primitive.__name__
-        column_name = task.primitive.column_name()
-        primitive_column = dataset.samples[column_name]
-        # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
-        msg_something_wrong = (
-            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
-            f"For '{primitive=}' and '{task.name=}' "
-        )
-        if primitive_column.dtype == pl.Null:
-            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+    format_version_of_dataset = Version(dataset_info.format_version)
 
-
-
-
-
-                + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
-            )
+    ## Backwards compatibility fixes for older dataset versions
+    if format_version_of_dataset <= Version("0.3.0"):
+        if ColumnName.DATASET_NAME not in samples.columns:
+            samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))
 
-
-
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
-            )
-        defined_classes = set(task.class_names)
+        if "file_name" in samples.columns:
+            samples = samples.rename({"file_name": ColumnName.FILE_PATH})
 
-        if not
-
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
-                f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
-            )
-        # Check class_indices
-        mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
-            lambda x: task.class_names.index(x), return_dtype=pl.Int64
-        )
-        table_indices = primitive_table[FieldName.CLASS_IDX]
+        if ColumnName.SAMPLE_INDEX not in samples.columns:
+            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
 
-
-
-
-
-
-    distribution = dataset.info.distributions or []
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
-    for PrimitiveType in PRIMITIVE_TYPES:
-        column_name = PrimitiveType.column_name()
-        if column_name not in dataset.samples.columns:
-            continue
-        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
-        if objects_df is None:
-            continue
-        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task:
-                continue
-            if task_name in distribution_names:
-                continue
-            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
-            raise ValueError(
-                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
-                f"classes: {class_names}. "
-            )
+    # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+    if ColumnName.TAGS not in samples.columns:
+        tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
 
-
-        sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation  # noqa: F841
+    return samples, dataset_info
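
Usage sketch (illustrative only, not part of the wheel contents): the 0.4.0 API surface visible in this diff could be exercised roughly as follows. The dataset name "mnist" is taken from the class_mapper docstring above; the output path is a hypothetical placeholder.

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset

# Load a platform-registered dataset (from_name now exposes force_redownload/download_files flags).
dataset = HafniaDataset.from_name("mnist", force_redownload=False, download_files=True)

# Transformations return a new dataset through the new update_samples() helper.
subset = dataset.select_samples(n_samples=100, shuffle=True, seed=42)

# Stats and checks are delegated to dataset_stats via the class-level function mappings.
subset.print_stats()

# write() stores annotations.jsonl/.parquet plus dataset_info.json; path is a placeholder.
subset.write(Path("./mnist_subset"))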