hafnia 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +16 -3
- cli/config.py +45 -4
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/keychain.py +88 -0
- cli/profile_cmds.py +10 -6
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +9 -14
- hafnia/dataset/dataset_names.py +10 -5
- hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
- hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +265 -56
- hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
- hafnia/dataset/hafnia_dataset.py +577 -213
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +259 -3
- hafnia/dataset/operations/dataset_transformations.py +332 -7
- hafnia/dataset/operations/table_transformations.py +43 -5
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/bbox.py +25 -12
- hafnia/dataset/primitives/bitmask.py +26 -14
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +16 -9
- hafnia/dataset/primitives/segmentation.py +10 -7
- hafnia/experiment/hafnia_logger.py +111 -8
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +104 -0
- hafnia/platform/datasets.py +47 -9
- hafnia/platform/download.py +25 -19
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +81 -13
- hafnia/visualizations/image_visualizations.py +4 -4
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
- hafnia-0.4.0.dist-info/RECORD +56 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/license_types.py

@@ -0,0 +1,63 @@
+from typing import List, Optional
+
+from hafnia.dataset.hafnia_dataset import License
+
+LICENSE_TYPES: List[License] = [
+    License(
+        name="Creative Commons: Attribution-NonCommercial-ShareAlike 2.0 Generic",
+        name_short="CC BY-NC-SA 2.0",
+        url="https://creativecommons.org/licenses/by-nc-sa/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NonCommercial 2.0 Generic",
+        name_short="CC BY-NC 2.0",
+        url="https://creativecommons.org/licenses/by-nc/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NonCommercial-NoDerivs 2.0 Generic",
+        name_short="CC BY-NC-ND 2.0",
+        url="https://creativecommons.org/licenses/by-nc-nd/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution 2.0 Generic",
+        name_short="CC BY 2.0",
+        url="https://creativecommons.org/licenses/by/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-ShareAlike 2.0 Generic",
+        name_short="CC BY-SA 2.0",
+        url="https://creativecommons.org/licenses/by-sa/2.0/",
+    ),
+    License(
+        name="Creative Commons: Attribution-NoDerivs 2.0 Generic",
+        name_short="CC BY-ND 2.0",
+        url="https://creativecommons.org/licenses/by-nd/2.0/",
+    ),
+    License(
+        name="Flickr: No known copyright restrictions",
+        name_short="Flickr",
+        url="https://flickr.com/commons/usage/",
+    ),
+    License(
+        name="United States Government Work",
+        name_short="US Gov",
+        url="http://www.usa.gov/copyright.shtml",
+    ),
+]
+
+
+def get_license_by_url(url: str) -> Optional[License]:
+    for license in LICENSE_TYPES:
+        # To handle http urls
+        license_url = (license.url or "").replace("http://", "https://")
+        url_https = url.replace("http://", "https://")
+        if license_url == url_https:
+            return license
+    raise ValueError(f"License with URL '{url}' not found.")
+
+
+def get_license_by_short_name(short_name: str) -> Optional[License]:
+    for license in LICENSE_TYPES:
+        if license.name_short == short_name:
+            return license
+    raise ValueError(f"License with short name '{short_name}' not found.")
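The new `hafnia/dataset/license_types.py` module bundles a fixed catalogue of eight licenses plus two lookup helpers. A minimal usage sketch (assuming hafnia 0.4.0 is installed; note that both helpers raise `ValueError` on a failed lookup rather than returning `None`, despite their `Optional[License]` return annotations):

```python
from hafnia.dataset.license_types import (
    LICENSE_TYPES,
    get_license_by_short_name,
    get_license_by_url,
)

# Look up a license by its short name; raises ValueError when no entry matches.
cc_by = get_license_by_short_name("CC BY 2.0")
print(cc_by.name)  # -> "Creative Commons: Attribution 2.0 Generic"

# URL lookup normalizes http:// to https:// on both sides before comparing,
# so an http URL still resolves to the https catalogue entry.
same = get_license_by_url("http://creativecommons.org/licenses/by/2.0/")
assert same is cc_by

print(len(LICENSE_TYPES))  # -> 8 bundled license entries
```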
hafnia/dataset/operations/dataset_stats.py

@@ -1,11 +1,21 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
-
+import polars as pl
+import rich
+from rich import print as rprint
+from rich.progress import track
+from rich.table import Table
 
-if TYPE_CHECKING:
+from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+from hafnia.dataset.operations.table_transformations import create_primitive_table
+from hafnia.dataset.primitives import PRIMITIVE_TYPES
+from hafnia.log import user_logger
+
+if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
     from hafnia.dataset.hafnia_dataset import HafniaDataset
+    from hafnia.dataset.primitives.primitive import Primitive
 
 
 def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
@@ -13,3 +23,249 @@ def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
     Returns a dictionary with the counts of samples in each split of the dataset.
     """
     return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+
+
+def class_counts_for_task(
+    dataset: HafniaDataset,
+    primitive: Optional[Type[Primitive]] = None,
+    task_name: Optional[str] = None,
+) -> Dict[str, int]:
+    """
+    Determines class name counts for a specific task in the dataset.
+
+    The counts are returned as a dictionary where keys are class names and values are their respective counts.
+    Note that class names with zero counts are included in the dictionary and that
+    the order of the dictionary matches the class idx.
+
+    >>> counts = dataset.class_counts_for_task(primitive=Bbox)
+    >>> print(counts)
+    {
+        'person': 0,  # Note: Zero count classes are included to maintain order with class idx
+        'car': 500,
+        'bicycle': 0,
+        'bus': 100,
+        'truck': 150,
+        'motorcycle': 50,
+    }
+    """
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+    class_counts_df = (
+        dataset.samples[task.primitive.column_name()]
+        .explode()
+        .struct.unnest()
+        .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+        .value_counts()
+    )
+
+    # Initialize counts with zero for all classes to ensure zero-count classes are included
+    # and to have class names in the order of class idx
+    class_counts = {name: 0 for name in task.class_names or []}
+    class_counts.update(dict(class_counts_df.iter_rows()))
+
+    return class_counts
+
+
+def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+    """
+    Get class counts for all tasks in the dataset.
+    The counts are returned as a dictionary where keys are in the format
+    '{primitive_name}/{task_name}/{class_name}' and values are their respective counts.
+
+    Example:
+    >>> counts = dataset.class_counts_all()
+    >>> print(counts)
+    {
+        objects/bboxes/car: 500
+        objects/bboxes/person: 0
+        classifications/weather/sunny: 300
+        classifications/weather/rainy: 0
+        ...
+    }
+    """
+    class_counts = {}
+    for task in dataset.info.tasks:
+        if task.class_names is None:
+            raise ValueError(f"Task '{task.name}' does not have class names defined.")
+        class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+        for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
+            count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
+            class_counts[count_name] = count
+
+    return class_counts
+
+
+def print_stats(dataset: HafniaDataset) -> None:
+    """
+    Prints verbose statistics about the dataset, including dataset name, version,
+    number of samples, and detailed counts of samples and tasks.
+    """
+    table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+    table_base.add_column("Property", style="cyan")
+    table_base.add_column("Value")
+    table_base.add_row("Dataset Name", dataset.info.dataset_name)
+    table_base.add_row("Version", dataset.info.version)
+    table_base.add_row("Number of samples", str(len(dataset.samples)))
+    rprint(table_base)
+    print_sample_and_task_counts(dataset)
+    print_class_distribution(dataset)
+
+
+def print_class_distribution(dataset: HafniaDataset) -> None:
+    """
+    Prints the class distribution for each task in the dataset.
+    """
+    for task in dataset.info.tasks:
+        if task.class_names is None:
+            raise ValueError(f"Task '{task.name}' does not have class names defined.")
+        class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+        # Print class distribution
+        rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+        rich_table.add_column("Class Name", style="cyan")
+        rich_table.add_column("Class Idx", style="cyan")
+        rich_table.add_column("Count", justify="right")
+        for class_name, count in class_counts.items():
+            class_idx = task.class_names.index(class_name)  # Get class idx from task info
+            rich_table.add_row(class_name, str(class_idx), str(count))
+        rprint(rich_table)
+
+
+def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
+    """
+    Prints a table with sample counts and task counts for each primitive type
+    in total and for each split (train, val, test).
+    """
+    from hafnia.dataset.operations.table_transformations import create_primitive_table
+    from hafnia.dataset.primitives import PRIMITIVE_TYPES
+
+    splits_sets = {
+        "All": SplitName.valid_splits(),
+        "Train": [SplitName.TRAIN],
+        "Validation": [SplitName.VAL],
+        "Test": [SplitName.TEST],
+    }
+    rows = []
+    for split_name, splits in splits_sets.items():
+        dataset_split = dataset.create_split_dataset(splits)
+        table = dataset_split.samples
+        row = {}
+        row["Split"] = split_name
+        row["Sample "] = str(len(table))
+        for PrimitiveType in PRIMITIVE_TYPES:
+            column_name = PrimitiveType.column_name()
+            objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+            if objects_df is None:
+                continue
+            for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                count = len(object_group[FieldName.CLASS_NAME])
+                row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+        rows.append(row)
+
+    rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+    for i_row, row in enumerate(rows):
+        if i_row == 0:
+            for column_name in row.keys():
+                rich_table.add_column(column_name, justify="left", style="cyan")
+        rich_table.add_row(*[str(value) for value in row.values()])
+    rprint(rich_table)
+
+
+def check_dataset(dataset: HafniaDataset):
+    """
+    Performs various checks on the dataset to ensure its integrity and consistency.
+    Raises errors if any issues are found.
+    """
+    from hafnia.dataset.hafnia_dataset import Sample
+
+    user_logger.info("Checking Hafnia dataset...")
+    assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+    sample_dataset = dataset.create_sample_dataset()
+    if len(sample_dataset) == 0:
+        raise ValueError("The dataset does not include a sample dataset")
+
+    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+    expected_splits = SplitName.valid_splits()
+    if set(actual_splits) != set(expected_splits):
+        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+    dataset.check_dataset_tasks()
+
+    expected_tasks = dataset.info.tasks
+    distribution = dataset.info.distributions or []
+    distribution_names = [task.name for task in distribution]
+    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+    for PrimitiveType in PRIMITIVE_TYPES:
+        column_name = PrimitiveType.column_name()
+        if column_name not in dataset.samples.columns:
+            continue
+        objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+        if objects_df is None:
+            continue
+        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+            has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+            if has_task or (task_name in distribution_names):
+                continue
+            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+            raise ValueError(
+                f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                f"classes: {class_names}. "
+            )
+
+    for sample_dict in track(dataset, description="Checking samples in dataset"):
+        sample = Sample(**sample_dict)  # noqa: F841
+
+
+def check_dataset_tasks(dataset: HafniaDataset):
+    """
+    Checks that the tasks defined in 'dataset.info.tasks' are consistent with the data in 'dataset.samples'.
+    """
+    dataset.info.check_for_duplicate_task_names()
+
+    for task in dataset.info.tasks:
+        primitive = task.primitive.__name__
+        column_name = task.primitive.column_name()
+        primitive_column = dataset.samples[column_name]
+        msg_something_wrong = (
+            f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+            f"For '{primitive=}' and '{task.name=}' "
+        )
+        if primitive_column.dtype == pl.Null:
+            raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+        if len(dataset) > 0:  # Check only performed for non-empty datasets
+            primitive_table = (
+                primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+            )
+            if primitive_table.is_empty():
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+                )
+
+            actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+            if task.class_names is None:
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+                )
+            defined_classes = set(task.class_names)
+
+            if not actual_classes.issubset(defined_classes):
+                raise ValueError(
+                    msg_something_wrong
+                    + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
+                    f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+                )
+            # Check class_indices
+            mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+                lambda x: task.class_names.index(x), return_dtype=pl.Int64
+            )
+            table_indices = primitive_table[FieldName.CLASS_IDX]
+
+            error_msg = msg_something_wrong + (
+                f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+            )
+            assert mapped_indices.equals(table_indices), error_msg