hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +13 -2
- cli/config.py +2 -1
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/profile_cmds.py +6 -5
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +0 -12
- hafnia/dataset/dataset_names.py +8 -4
- hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
- hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +206 -53
- hafnia/dataset/hafnia_dataset.py +432 -194
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +260 -3
- hafnia/dataset/operations/dataset_transformations.py +325 -4
- hafnia/dataset/operations/table_transformations.py +39 -2
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/classification.py +1 -1
- hafnia/experiment/hafnia_logger.py +112 -0
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +99 -0
- hafnia/platform/datasets.py +44 -6
- hafnia/platform/download.py +2 -1
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +64 -13
- hafnia/visualizations/image_visualizations.py +3 -3
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
- hafnia-0.3.0.dist-info/RECORD +53 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/hafnia_dataset.py
CHANGED
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+import collections
+import copy
+import json
 import shutil
 from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
 from random import Random
 from typing import Any, Dict, List, Optional, Type, Union
@@ -9,13 +13,11 @@ from typing import Any, Dict, List, Optional, Type, Union
 import more_itertools
 import numpy as np
 import polars as pl
-import rich
 from PIL import Image
-from pydantic import BaseModel, field_serializer, field_validator
-from rich import print as rprint
-from rich.table import Table
+from pydantic import BaseModel, Field, field_serializer, field_validator
 from tqdm import tqdm
 
+import hafnia
 from hafnia.dataset import dataset_helpers
 from hafnia.dataset.dataset_names import (
     DATASET_FILENAMES_REQUIRED,
@@ -23,20 +25,16 @@ from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_PARQUET,
     FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
+    TAG_IS_SAMPLE,
     ColumnName,
-    FieldName,
     SplitName,
 )
-from hafnia.dataset.operations import dataset_stats, dataset_transformations
+from hafnia.dataset.operations import dataset_stats, dataset_transformations, table_transformations
 from hafnia.dataset.operations.table_transformations import (
     check_image_paths,
-    create_primitive_table,
     read_table_from_path,
 )
-from hafnia.dataset.primitives import
-    PRIMITIVE_NAME_TO_TYPE,
-    PRIMITIVE_TYPES,
-)
+from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
 from hafnia.dataset.primitives.bbox import Bbox
 from hafnia.dataset.primitives.bitmask import Bitmask
 from hafnia.dataset.primitives.classification import Classification
@@ -48,9 +46,7 @@ from hafnia.log import user_logger
 class TaskInfo(BaseModel):
     primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
     class_names: Optional[List[str]]  # Class names for the tasks. To get consistent class indices specify class_names.
-    name: Optional[str] =
-        None  # None to use the default primitive task name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
-    )
+    name: Optional[str] = None  # Use 'None' to use default name Bbox ->"bboxes", Bitmask -> "bitmasks" etc.
 
     def model_post_init(self, __context: Any) -> None:
         if self.name is None:
@@ -64,12 +60,7 @@ class TaskInfo(BaseModel):
     @classmethod
     def ensure_primitive(cls, primitive: Any) -> Any:
         if isinstance(primitive, str):
-
-            raise ValueError(
-                f"Primitive '{primitive}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
-            )
-
-            return PRIMITIVE_NAME_TO_TYPE[primitive]
+            return get_primitive_type_from_string(primitive)
 
         if issubclass(primitive, Primitive):
             return primitive
@@ -83,22 +74,187 @@ class TaskInfo(BaseModel):
             raise ValueError(f"Primitive must be a subclass of Primitive, got {type(primitive)} instead.")
         return primitive.__name__
 
+    @field_validator("class_names", mode="after")
+    @classmethod
+    def validate_unique_class_names(cls, class_names: Optional[List[str]]) -> Optional[List[str]]:
+        """Validate that class names are unique"""
+        if class_names is None:
+            return None
+        duplicate_class_names = set([name for name in class_names if class_names.count(name) > 1])
+        if duplicate_class_names:
+            raise ValueError(
+                f"Class names must be unique. The following class names appear multiple times: {duplicate_class_names}."
+            )
+        return class_names
+
+    # To get unique hash value for TaskInfo objects
+    def __hash__(self) -> int:
+        class_names = self.class_names or []
+        return hash((self.name, self.primitive.__name__, tuple(class_names)))
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, TaskInfo):
+            return False
+        return self.name == other.name and self.primitive == other.primitive and self.class_names == other.class_names
+
 
 class DatasetInfo(BaseModel):
     dataset_name: str
-    version: str
-    tasks:
+    version: str  # Dataset version. This is not the same as the Hafnia dataset format version.
+    tasks: List[TaskInfo]
     distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
     meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.
+    format_version: str = hafnia.__dataset_format_version__  # Version of the Hafnia dataset format
+    updated_at: datetime = datetime.now()
+
+    @field_validator("tasks", mode="after")
+    @classmethod
+    def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+        task_name_counts = collections.Counter(task.name for task in tasks)
+        duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
+        if duplicate_task_names:
+            raise ValueError(
+                f"Tasks must be unique. The following tasks appear multiple times: {duplicate_task_names}."
+            )
+        return tasks
+
+    def check_for_duplicate_task_names(self) -> List[TaskInfo]:
+        return self._validate_check_for_duplicate_tasks(self.tasks)
 
     def write_json(self, path: Path, indent: Optional[int] = 4) -> None:
         json_str = self.model_dump_json(indent=indent)
         path.write_text(json_str)
 
     @staticmethod
-    def from_json_file(path: Path) ->
+    def from_json_file(path: Path) -> DatasetInfo:
         json_str = path.read_text()
-
+
+        # TODO: Deprecated support for old dataset info without format_version
+        # Below 4 lines can be replaced by 'dataset_info = DatasetInfo.model_validate_json(json_str)'
+        # when all datasets include a 'format_version' field
+        json_dict = json.loads(json_str)
+        if "format_version" not in json_dict:
+            json_dict["format_version"] = "0.0.0"
+
+        if "updated_at" not in json_dict:
+            json_dict["updated_at"] = datetime.min.isoformat()
+        dataset_info = DatasetInfo.model_validate(json_dict)
+
+        return dataset_info
+
+    @staticmethod
+    def merge(info0: DatasetInfo, info1: DatasetInfo) -> DatasetInfo:
+        """
+        Merges two DatasetInfo objects into one and validates if they are compatible.
+        """
+        for task_ds0 in info0.tasks:
+            for task_ds1 in info1.tasks:
+                same_name = task_ds0.name == task_ds1.name
+                same_primitive = task_ds0.primitive == task_ds1.primitive
+                same_name_different_primitive = same_name and not same_primitive
+                if same_name_different_primitive:
+                    raise ValueError(
+                        f"Cannot merge datasets with different primitives for the same task name: "
+                        f"'{task_ds0.name}' has primitive '{task_ds0.primitive}' in dataset0 and "
+                        f"'{task_ds1.primitive}' in dataset1."
+                    )
+
+                is_same_name_and_primitive = same_name and same_primitive
+                if is_same_name_and_primitive:
+                    task_ds0_class_names = task_ds0.class_names or []
+                    task_ds1_class_names = task_ds1.class_names or []
+                    if task_ds0_class_names != task_ds1_class_names:
+                        raise ValueError(
+                            f"Cannot merge datasets with different class names for the same task name and primitive: "
+                            f"'{task_ds0.name}' with primitive '{task_ds0.primitive}' has class names "
+                            f"{task_ds0_class_names} in dataset0 and {task_ds1_class_names} in dataset1."
+                        )
+
+        if info1.format_version != info0.format_version:
+            user_logger.warning(
+                "Dataset format version of the two datasets do not match. "
+                f"'{info1.format_version}' vs '{info0.format_version}'."
+            )
+        dataset_format_version = info0.format_version
+        if hafnia.__dataset_format_version__ != dataset_format_version:
+            user_logger.warning(
+                f"Dataset format version '{dataset_format_version}' does not match the current "
+                f"Hafnia format version '{hafnia.__dataset_format_version__}'."
+            )
+        unique_tasks = set(info0.tasks + info1.tasks)
+        distributions = set((info0.distributions or []) + (info1.distributions or []))
+        meta = (info0.meta or {}).copy()
+        meta.update(info1.meta or {})
+        return DatasetInfo(
+            dataset_name=info0.dataset_name + "+" + info1.dataset_name,
+            version="merged",
+            tasks=list(unique_tasks),
+            distributions=list(distributions),
+            meta=meta,
+            format_version=dataset_format_version,
+        )
+
+    def get_task_by_name(self, task_name: str) -> TaskInfo:
+        """
+        Get task by its name. Raises an error if the task name is not found or if multiple tasks have the same name.
+        """
+        tasks_with_name = [task for task in self.tasks if task.name == task_name]
+        if not tasks_with_name:
+            raise ValueError(f"Task with name '{task_name}' not found in dataset info.")
+        if len(tasks_with_name) > 1:
+            raise ValueError(f"Multiple tasks found with name '{task_name}'. This should not happen!")
+        return tasks_with_name[0]
+
+    def get_task_by_primitive(self, primitive: Union[Type[Primitive], str]) -> TaskInfo:
+        """
+        Get task by its primitive type. Raises an error if the primitive type is not found or if multiple tasks
+        have the same primitive type.
+        """
+        if isinstance(primitive, str):
+            primitive = get_primitive_type_from_string(primitive)
+
+        tasks_with_primitive = [task for task in self.tasks if task.primitive == primitive]
+        if not tasks_with_primitive:
+            raise ValueError(f"Task with primitive {primitive} not found in dataset info.")
+        if len(tasks_with_primitive) > 1:
+            raise ValueError(
+                f"Multiple tasks found with primitive {primitive}. Use '{self.get_task_by_name.__name__}' instead."
+            )
+        return tasks_with_primitive[0]
+
+    def get_task_by_task_name_and_primitive(
+        self,
+        task_name: Optional[str],
+        primitive: Optional[Union[Type[Primitive], str]],
+    ) -> TaskInfo:
+        """
+        Logic to get a unique task based on the provided 'task_name' and/or 'primitive'.
+        If both 'task_name' and 'primitive' are None, the dataset must have only one task.
+        """
+        task = dataset_transformations.get_task_info_from_task_name_and_primitive(
+            tasks=self.tasks,
+            primitive=primitive,
+            task_name=task_name,
+        )
+        return task
+
+    def replace_task(self, old_task: TaskInfo, new_task: Optional[TaskInfo]) -> DatasetInfo:
+        dataset_info = self.model_copy(deep=True)
+        has_task = any(t for t in dataset_info.tasks if t.name == old_task.name and t.primitive == old_task.primitive)
+        if not has_task:
+            raise ValueError(f"Task '{old_task.__repr__()}' not found in dataset info.")
+
+        new_tasks = []
+        for task in dataset_info.tasks:
+            if task.name == old_task.name and task.primitive == old_task.primitive:
+                if new_task is None:
+                    continue  # Remove the task
+                new_tasks.append(new_task)
+            else:
+                new_tasks.append(task)
+
+        dataset_info.tasks = new_tasks
+        return dataset_info
 
 
 class Sample(BaseModel):
@@ -106,7 +262,7 @@ class Sample(BaseModel):
     height: int
     width: int
     split: str  # Split name, e.g., "train", "val", "test"
-
+    tags: List[str] = []  # tags for a given sample. Used for creating subsets of the dataset.
     collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
     collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
     remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
@@ -116,6 +272,7 @@
     bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
     polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable
 
+    attribution: Optional[Attribution] = None  # Attribution information for the image
     meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.
 
     def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
@@ -158,11 +315,93 @@ class Sample(BaseModel):
         return annotations_visualized
 
 
+class License(BaseModel):
+    """License information"""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="License name. E.g. 'Creative Commons: Attribution 2.0 Generic'",
+        max_length=100,
+    )
+    name_short: Optional[str] = Field(
+        default=None,
+        description="License short name or abbreviation. E.g. 'CC BY 4.0'",
+        max_length=100,
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="License URL e.g. https://creativecommons.org/licenses/by/4.0/",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description=(
+            "License description e.g. 'You must give appropriate credit, provide a "
+            "link to the license, and indicate if changes were made.'"
+        ),
+    )
+
+    valid_date: Optional[datetime] = Field(
+        default=None,
+        description="License valid date. E.g. '2023-01-01T00:00:00Z'",
+    )
+
+    permissions: Optional[List[str]] = Field(
+        default=None,
+        description="License permissions. Allowed to Access, Label, Distribute, Represent and Modify data.",
+    )
+    liability: Optional[str] = Field(
+        default=None,
+        description="License liability. Optional and not always applicable.",
+    )
+    location: Optional[str] = Field(
+        default=None,
+        description=(
+            "License Location. E.g. Iowa state. This is essential to understand the industry and "
+            "privacy location specific rules that applies to the data. Optional and not always applicable."
+        ),
+    )
+    notes: Optional[str] = Field(
+        default=None,
+        description="Additional license notes. Optional and not always applicable.",
+    )
+
+
+class Attribution(BaseModel):
+    """Attribution information for the image: Giving source and credit to the original creator"""
+
+    title: Optional[str] = Field(default=None, description="Title of the image", max_length=255)
+    creator: Optional[str] = Field(default=None, description="Creator of the image", max_length=255)
+    creator_url: Optional[str] = Field(default=None, description="URL of the creator", max_length=255)
+    date_captured: Optional[datetime] = Field(default=None, description="Date when the image was captured")
+    copyright_notice: Optional[str] = Field(default=None, description="Copyright notice for the image", max_length=255)
+    licenses: Optional[List[License]] = Field(default=None, description="List of licenses for the image")
+    disclaimer: Optional[str] = Field(default=None, description="Disclaimer for the image", max_length=255)
+    changes: Optional[str] = Field(default=None, description="Changes made to the image", max_length=255)
+    source_url: Optional[str] = Field(default=None, description="Source URL for the image", max_length=255)
+
+
 @dataclass
 class HafniaDataset:
     info: DatasetInfo
     samples: pl.DataFrame
 
+    # Function mapping: Dataset stats
+    split_counts = dataset_stats.split_counts
+    class_counts_for_task = dataset_stats.class_counts_for_task
+    class_counts_all = dataset_stats.class_counts_all
+
+    # Function mapping: Print stats
+    print_stats = dataset_stats.print_stats
+    print_sample_and_task_counts = dataset_stats.print_sample_and_task_counts
+    print_class_distribution = dataset_stats.print_class_distribution
+
+    # Function mapping: Dataset checks
+    check_dataset = dataset_stats.check_dataset
+    check_dataset_tasks = dataset_stats.check_dataset_tasks
+
+    # Function mapping: Dataset transformations
+    transform_images = dataset_transformations.transform_images
+
     def __getitem__(self, item: int) -> Dict[str, Any]:
         return self.samples.row(index=item, named=True)
 
@@ -173,6 +412,18 @@ class HafniaDataset:
         for row in self.samples.iter_rows(named=True):
             yield row
 
+    def __post_init__(self):
+        samples = self.samples
+        if ColumnName.SAMPLE_INDEX not in samples.columns:
+            samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+        # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+        if ColumnName.TAGS not in samples.columns:
+            tags_column: List[List[str]] = [[] for _ in range(len(self))]  # type: ignore[annotation-unchecked]
+            samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
+
+        self.samples = samples
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
         HafniaDataset.check_dataset_path(path_folder, raise_error=True)
@@ -192,11 +443,12 @@
         """
         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
         """
-        from hafnia.dataset.hafnia_dataset import HafniaDataset
        from hafnia.platform.datasets import download_or_get_dataset_path
 
         dataset_path = download_or_get_dataset_path(
-            dataset_name=name,
+            dataset_name=name,
+            force_redownload=force_redownload,
+            download_files=download_files,
         )
         return HafniaDataset.from_path(dataset_path, check_for_images=download_files)
 
@@ -210,9 +462,8 @@
             else:
                 raise TypeError(f"Unsupported sample type: {type(sample)}. Expected Sample or dict.")
 
-        table = pl.from_records(json_samples)
-        table = table.with_row_index(name=ColumnName.SAMPLE_INDEX)
-
+        table = pl.from_records(json_samples)
+        table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
         return HafniaDataset(info=info, samples=table)
 
     @staticmethod
@@ -241,7 +492,11 @@
        If the dataset is already cached, it will be loaded from the cache.
        """
 
-        path_dataset = get_or_create_dataset_path_from_recipe(
+        path_dataset = get_or_create_dataset_path_from_recipe(
+            dataset_recipe,
+            path_datasets=path_datasets,
+            force_redownload=force_redownload,
+        )
         return HafniaDataset.from_path(path_dataset, check_for_images=False)
 
     @staticmethod
@@ -263,20 +518,24 @@
         merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
         return merged_dataset
 
-    # Dataset transformations
-    transform_images = dataset_transformations.transform_images
-
     def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
         table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
-        return dataset.
+        return dataset.update_samples(table)
 
     def select_samples(
-        dataset: "HafniaDataset",
+        dataset: "HafniaDataset",
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
    ) -> "HafniaDataset":
+        """
+        Create a new dataset with a subset of samples.
+        """
        if not with_replacement:
             n_samples = min(n_samples, len(dataset))
         table = dataset.samples.sample(n=n_samples, with_replacement=with_replacement, seed=seed, shuffle=shuffle)
-        return dataset.
+        return dataset.update_samples(table)
 
     def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
         """
@@ -295,7 +554,7 @@
             split_ratios=split_ratios, n_items=n_items, seed=seed
         )
         table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
-        return dataset.
+        return dataset.update_samples(table)
 
     def split_into_multiple_splits(
         dataset: "HafniaDataset",
@@ -323,33 +582,124 @@
 
         remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([split_name]).not_())
         new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
-        dataset_new = dataset.
+        dataset_new = dataset.update_samples(new_table)
         return dataset_new
 
     def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+        """
+        Defines a sample set randomly by selecting 'n_samples' samples from the dataset.
+        """
+        samples = dataset.samples
+
+        # Remove any pre-existing "sample"-tags
+        samples = samples.with_columns(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() != TAG_IS_SAMPLE)).alias(ColumnName.TAGS)
+        )
+
+        # Add "sample" to tags column for the selected samples
         is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
-
-
-
+        samples = samples.with_columns(
+            pl.when(pl.int_range(len(samples)).is_in(is_sample_indices))
+            .then(pl.col(ColumnName.TAGS).list.concat(pl.lit([TAG_IS_SAMPLE])))
+            .otherwise(pl.col(ColumnName.TAGS))
+        )
+        return dataset.update_samples(samples)
+
+    def class_mapper(
+        dataset: "HafniaDataset",
+        class_mapping: Dict[str, str],
+        method: str = "strict",
+        primitive: Optional[Type[Primitive]] = None,
+        task_name: Optional[str] = None,
+    ) -> "HafniaDataset":
+        """
+        Map class names to new class names using a strict mapping.
+        A strict mapping means that all class names in the dataset must be mapped to a new class name.
+        If a class name is not mapped, an error is raised.
+
+        The class indices are determined by the order of appearance of the new class names in the mapping.
+        Duplicates in the new class names are removed, preserving the order of first appearance.
+
+        E.g.
+
+        mnist = HafniaDataset.from_name("mnist")
+        strict_class_mapping = {
+            "1 - one": "odd",  # 'odd' appears first and becomes class index 0
+            "3 - three": "odd",
+            "5 - five": "odd",
+            "7 - seven": "odd",
+            "9 - nine": "odd",
+            "0 - zero": "even",  # 'even' appears second and becomes class index 1
+            "2 - two": "even",
+            "4 - four": "even",
+            "6 - six": "even",
+            "8 - eight": "even",
+        }
+
+        dataset_new = class_mapper(dataset=mnist, class_mapping=strict_class_mapping)
 
-
-        return
+        """
+        return dataset_transformations.class_mapper(
+            dataset=dataset,
+            class_mapping=class_mapping,
+            method=method,
+            primitive=primitive,
+            task_name=task_name,
+        )
+
+    def rename_task(
+        dataset: "HafniaDataset",
+        old_task_name: str,
+        new_task_name: str,
+    ) -> "HafniaDataset":
+        """
+        Rename a task in the dataset.
+        """
+        return dataset_transformations.rename_task(
+            dataset=dataset, old_task_name=old_task_name, new_task_name=new_task_name
+        )
+
+    def select_samples_by_class_name(
+        dataset: HafniaDataset,
+        name: Union[List[str], str],
+        task_name: Optional[str] = None,
+        primitive: Optional[Type[Primitive]] = None,
+    ) -> HafniaDataset:
+        """
+        Select samples that contain at least one annotation with the specified class name(s).
+        If 'task_name' and 'primitive' are not provided, the function will attempt to infer the task.
+        """
+        return dataset_transformations.select_samples_by_class_name(
+            dataset=dataset, name=name, task_name=task_name, primitive=primitive
+        )
 
     def merge(dataset0: "HafniaDataset", dataset1: "HafniaDataset") -> "HafniaDataset":
         """
         Merges two Hafnia datasets by concatenating their samples and updating the split names.
         """
-        ## Currently, only a very naive merging is implemented.
-        # In the future we need to verify that the class and tasks are compatible.
-        # Do they have similar classes and tasks? What to do if they don't?
-        # For now, we just concatenate the samples and keep the split names as they are.
-        merged_samples = pl.concat([dataset0.samples, dataset1.samples], how="vertical")
-        return dataset0.update_table(merged_samples)
 
-
-
+        # Merges dataset info and checks for compatibility
+        merged_info = DatasetInfo.merge(dataset0.info, dataset1.info)
+
+        # Merges samples tables (removes incompatible columns)
+        merged_samples = table_transformations.merge_samples(samples0=dataset0.samples, samples1=dataset1.samples)
+
+        # Check if primitives have been removed during the merge_samples
+        for task in copy.deepcopy(merged_info.tasks):
+            if task.primitive.column_name() not in merged_samples.columns:
+                user_logger.warning(
+                    f"Task '{task.name}' with primitive '{task.primitive.__name__}' has been removed during the merge. "
+                    "This happens if the two datasets do not have the same primitives."
+                )
+                merged_info = merged_info.replace_task(old_task=task, new_task=None)
+
+        return HafniaDataset(info=merged_info, samples=merged_samples)
 
     def as_dict_dataset_splits(self) -> Dict[str, "HafniaDataset"]:
+        """
+        Splits the dataset into multiple datasets based on the 'split' column.
+        Returns a dictionary with split names as keys and HafniaDataset objects as values.
+        """
         if ColumnName.SPLIT not in self.samples.columns:
             raise ValueError(f"Dataset must contain a '{ColumnName.SPLIT}' column.")
 
@@ -360,10 +710,22 @@
         return splits
 
     def create_sample_dataset(self) -> "HafniaDataset":
-
-
-
-
+        # Backwards compatibility. Remove in future versions when dataset have been updated
+        if "is_sample" in self.samples.columns:
+            user_logger.warning(
+                "'is_sample' column found in the dataset. This column is deprecated and will be removed in future versions. "
+                "Please use the 'tags' column with the tag 'sample' instead."
+            )
+            table = self.samples.filter(pl.col("is_sample") == True)  # noqa: E712
+            return self.update_samples(table)
+
+        if ColumnName.TAGS not in self.samples.columns:
+            raise ValueError(f"Dataset must contain an '{ColumnName.TAGS}' column.")
+
+        table = self.samples.filter(
+            pl.col(ColumnName.TAGS).list.eval(pl.element().filter(pl.element() == TAG_IS_SAMPLE)).list.len() > 0
+        )
+        return self.update_samples(table)
 
     def create_split_dataset(self, split_name: Union[str | List[str]]) -> "HafniaDataset":
         if isinstance(split_name, str):
@@ -376,16 +738,12 @@
             raise ValueError(f"Invalid split name: {split_name}. Valid splits are: {SplitName.valid_splits()}")
 
         filtered_dataset = self.samples.filter(pl.col(ColumnName.SPLIT).is_in(split_names))
-        return self.
+        return self.update_samples(filtered_dataset)
 
-    def
-
-
-
-            raise ValueError(f"Task with name {task_name} not found in dataset info.")
-
-    def update_table(self, table: pl.DataFrame) -> "HafniaDataset":
-        return HafniaDataset(info=self.info.model_copy(), samples=table)
+    def update_samples(self, table: pl.DataFrame) -> "HafniaDataset":
+        dataset = HafniaDataset(info=self.info.model_copy(deep=True), samples=table)
+        dataset.check_dataset_tasks()
+        return dataset
 
     @staticmethod
     def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
@@ -411,7 +769,10 @@
 
         return True
 
-    def
+    def copy(self) -> "HafniaDataset":
+        return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
+
+    def write(self, path_folder: Path, add_version: bool = False, drop_null_cols: bool = True) -> None:
         user_logger.info(f"Writing dataset to {path_folder}...")
         if not path_folder.exists():
             path_folder.mkdir(parents=True)
@@ -424,6 +785,10 @@
             )
             new_relative_paths.append(str(new_path.relative_to(path_folder)))
         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+
+        if drop_null_cols:  # Drops all unused/Null columns
+            table = table.drop(pl.selectors.by_dtype(pl.Null))
+
         table.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
         table.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
         self.info.write_json(path_folder / FILENAME_DATASET_INFO)
@@ -448,51 +813,10 @@
             return False
         return True
 
-    def print_stats(self) -> None:
-        table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        table_base.add_column("Property", style="cyan")
-        table_base.add_column("Value")
-        table_base.add_row("Dataset Name", self.info.dataset_name)
-        table_base.add_row("Version", self.info.version)
-        table_base.add_row("Number of samples", str(len(self.samples)))
-        rprint(table_base)
-        rprint(self.info.tasks)
-
-        splits_sets = {
-            "All": SplitName.valid_splits(),
-            "Train": [SplitName.TRAIN],
-            "Validation": [SplitName.VAL],
-            "Test": [SplitName.TEST],
-        }
-        rows = []
-        for split_name, splits in splits_sets.items():
-            dataset_split = self.create_split_dataset(splits)
-            table = dataset_split.samples
-            row = {}
-            row["Split"] = split_name
-            row["Sample "] = str(len(table))
-            for PrimitiveType in PRIMITIVE_TYPES:
-                column_name = PrimitiveType.column_name()
-                objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-                if objects_df is None:
-                    continue
-                for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                    count = len(object_group[FieldName.CLASS_NAME])
-                    row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-            rows.append(row)
-
-        rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
-        for i_row, row in enumerate(rows):
-            if i_row == 0:
-                for column_name in row.keys():
-                    rich_table.add_column(column_name, justify="left", style="cyan")
-            rich_table.add_row(*[str(value) for value in row.values()])
-        rprint(rich_table)
-
 
 def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
-
+    dataset.check_dataset()
 
 
 def get_or_create_dataset_path_from_recipe(
@@ -522,89 +846,3 @@
     dataset.write(path_dataset)
 
     return path_dataset
-
-
-def check_hafnia_dataset(dataset: HafniaDataset):
-    user_logger.info("Checking Hafnia dataset...")
-    assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
-    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
-
-    is_sample_list = set(dataset.samples.select(pl.col(ColumnName.IS_SAMPLE)).unique().to_series().to_list())
-    if True not in is_sample_list:
-        raise ValueError(f"The dataset should contain '{ColumnName.IS_SAMPLE}=True' samples")
-
-    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-    expected_splits = SplitName.valid_splits()
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
-
-    expected_tasks = dataset.info.tasks
-    for task in expected_tasks:
-        primitive = task.primitive.__name__
-        column_name = task.primitive.column_name()
-        primitive_column = dataset.samples[column_name]
-        # msg_something_wrong = f"Something is wrong with the '{primtive_name}' task '{task.name}' in dataset '{dataset.name}'. "
-        msg_something_wrong = (
-            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
-            f"For '{primitive=}' and '{task.name=}' "
-        )
-        if primitive_column.dtype == pl.Null:
-            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
-
-        primitive_table = primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
-        if primitive_table.is_empty():
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
-            )
-
-        actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
-        if task.class_names is None:
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
-            )
-        defined_classes = set(task.class_names)
-
-        if not actual_classes.issubset(defined_classes):
-            raise ValueError(
-                msg_something_wrong
-                + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
-                f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
-            )
-        # Check class_indices
-        mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
-            lambda x: task.class_names.index(x), return_dtype=pl.Int64
-        )
-        table_indices = primitive_table[FieldName.CLASS_IDX]
-
-        error_msg = msg_something_wrong + (
-            f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
-        )
-        assert mapped_indices.equals(table_indices), error_msg
-
-    distribution = dataset.info.distributions or []
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
-    for PrimitiveType in PRIMITIVE_TYPES:
-        column_name = PrimitiveType.column_name()
-        if column_name not in dataset.samples.columns:
-            continue
-        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
-        if objects_df is None:
-            continue
-        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task:
-                continue
-            if task_name in distribution_names:
-                continue
-            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
-            raise ValueError(
-                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
-                f"classes: {class_names}. "
-            )
-
-    for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
-        sample = Sample(**sample_dict)  # Checks format of all samples with pydantic validation  # noqa: F841