hafnia-0.2.4-py3-none-any.whl → hafnia-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. cli/__main__.py +16 -3
  2. cli/config.py +45 -4
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/keychain.py +88 -0
  8. cli/profile_cmds.py +10 -6
  9. cli/runc_cmds.py +5 -5
  10. cli/trainer_package_cmds.py +65 -0
  11. hafnia/__init__.py +2 -0
  12. hafnia/data/factory.py +1 -2
  13. hafnia/dataset/dataset_helpers.py +9 -14
  14. hafnia/dataset/dataset_names.py +10 -5
  15. hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
  16. hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
  17. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  18. hafnia/dataset/dataset_upload_helper.py +265 -56
  19. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  20. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  21. hafnia/dataset/hafnia_dataset.py +577 -213
  22. hafnia/dataset/license_types.py +63 -0
  23. hafnia/dataset/operations/dataset_stats.py +259 -3
  24. hafnia/dataset/operations/dataset_transformations.py +332 -7
  25. hafnia/dataset/operations/table_transformations.py +43 -5
  26. hafnia/dataset/primitives/__init__.py +8 -0
  27. hafnia/dataset/primitives/bbox.py +25 -12
  28. hafnia/dataset/primitives/bitmask.py +26 -14
  29. hafnia/dataset/primitives/classification.py +16 -8
  30. hafnia/dataset/primitives/point.py +7 -3
  31. hafnia/dataset/primitives/polygon.py +16 -9
  32. hafnia/dataset/primitives/segmentation.py +10 -7
  33. hafnia/experiment/hafnia_logger.py +111 -8
  34. hafnia/http.py +16 -2
  35. hafnia/platform/__init__.py +9 -3
  36. hafnia/platform/builder.py +12 -10
  37. hafnia/platform/dataset_recipe.py +104 -0
  38. hafnia/platform/datasets.py +47 -9
  39. hafnia/platform/download.py +25 -19
  40. hafnia/platform/experiment.py +51 -56
  41. hafnia/platform/trainer_package.py +57 -0
  42. hafnia/utils.py +81 -13
  43. hafnia/visualizations/image_visualizations.py +4 -4
  44. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
  45. hafnia-0.4.0.dist-info/RECORD +56 -0
  46. cli/recipe_cmds.py +0 -45
  47. hafnia-0.2.4.dist-info/RECORD +0 -49
  48. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  49. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  50. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/license_types.py
@@ -0,0 +1,63 @@
+ from typing import List, Optional
+
+ from hafnia.dataset.hafnia_dataset import License
+
+ LICENSE_TYPES: List[License] = [
+     License(
+         name="Creative Commons: Attribution-NonCommercial-ShareAlike 2.0 Generic",
+         name_short="CC BY-NC-SA 2.0",
+         url="https://creativecommons.org/licenses/by-nc-sa/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NonCommercial 2.0 Generic",
+         name_short="CC BY-NC 2.0",
+         url="https://creativecommons.org/licenses/by-nc/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NonCommercial-NoDerivs 2.0 Generic",
+         name_short="CC BY-NC-ND 2.0",
+         url="https://creativecommons.org/licenses/by-nc-nd/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution 2.0 Generic",
+         name_short="CC BY 2.0",
+         url="https://creativecommons.org/licenses/by/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-ShareAlike 2.0 Generic",
+         name_short="CC BY-SA 2.0",
+         url="https://creativecommons.org/licenses/by-sa/2.0/",
+     ),
+     License(
+         name="Creative Commons: Attribution-NoDerivs 2.0 Generic",
+         name_short="CC BY-ND 2.0",
+         url="https://creativecommons.org/licenses/by-nd/2.0/",
+     ),
+     License(
+         name="Flickr: No known copyright restrictions",
+         name_short="Flickr",
+         url="https://flickr.com/commons/usage/",
+     ),
+     License(
+         name="United States Government Work",
+         name_short="US Gov",
+         url="http://www.usa.gov/copyright.shtml",
+     ),
+ ]
+
+
+ def get_license_by_url(url: str) -> Optional[License]:
+     for license in LICENSE_TYPES:
+         # To handle http urls
+         license_url = (license.url or "").replace("http://", "https://")
+         url_https = url.replace("http://", "https://")
+         if license_url == url_https:
+             return license
+     raise ValueError(f"License with URL '{url}' not found.")
+
+
+ def get_license_by_short_name(short_name: str) -> Optional[License]:
+     for license in LICENSE_TYPES:
+         if license.name_short == short_name:
+             return license
+     raise ValueError(f"License with short name '{short_name}' not found.")
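
For context, a minimal usage sketch of the two lookup helpers added above (assumes hafnia 0.4.0 is installed). Note that despite the Optional[License] return annotations, both helpers raise ValueError on a miss instead of returning None:

    from hafnia.dataset.license_types import get_license_by_short_name, get_license_by_url

    license = get_license_by_short_name("CC BY 2.0")
    print(license.name)  # Creative Commons: Attribution 2.0 Generic

    # Both sides of the URL comparison are normalized from http:// to https://,
    # so a legacy http form of the URL resolves to the same License entry.
    assert get_license_by_url("http://creativecommons.org/licenses/by/2.0/") is license

    # Unknown identifiers raise ValueError instead of returning None.
    try:
        get_license_by_short_name("MIT")
    except ValueError as err:
        print(err)  # License with short name 'MIT' not found.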
hafnia/dataset/operations/dataset_stats.py
@@ -1,11 +1,21 @@
  from __future__ import annotations

- from typing import TYPE_CHECKING, Dict
+ from typing import TYPE_CHECKING, Dict, Optional, Type

- from hafnia.dataset.dataset_names import ColumnName
+ import polars as pl
+ import rich
+ from rich import print as rprint
+ from rich.progress import track
+ from rich.table import Table

- if TYPE_CHECKING:
+ from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+ from hafnia.dataset.operations.table_transformations import create_primitive_table
+ from hafnia.dataset.primitives import PRIMITIVE_TYPES
+ from hafnia.log import user_logger
+
+ if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
      from hafnia.dataset.hafnia_dataset import HafniaDataset
+     from hafnia.dataset.primitives.primitive import Primitive


  def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
@@ -13,3 +23,249 @@ def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
      Returns a dictionary with the counts of samples in each split of the dataset.
      """
      return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+
+
+ def class_counts_for_task(
+     dataset: HafniaDataset,
+     primitive: Optional[Type[Primitive]] = None,
+     task_name: Optional[str] = None,
+ ) -> Dict[str, int]:
+     """
+     Determines class name counts for a specific task in the dataset.
+
+     The counts are returned as a dictionary where keys are class names and values are their respective counts.
+     Note that class names with zero counts are included in the dictionary and that
+     the order of the dictionary matches the class idx.
+
+     >>> counts = dataset.class_counts_for_task(primitive=Bbox)
+     >>> print(counts)
+     {
+         'person': 0,  # Note: Zero count classes are included to maintain order with class idx
+         'car': 500,
+         'bicycle': 0,
+         'bus': 100,
+         'truck': 150,
+         'motorcycle': 50,
+     }
+     """
+     task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=primitive)
+     class_counts_df = (
+         dataset.samples[task.primitive.column_name()]
+         .explode()
+         .struct.unnest()
+         .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+         .value_counts()
+     )
+
+     # Initialize counts with zero for all classes to ensure zero-count classes are included
+     # and to have class names in the order of class idx
+     class_counts = {name: 0 for name in task.class_names or []}
+     class_counts.update(dict(class_counts_df.iter_rows()))
+
+     return class_counts
+
+
+ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+     """
+     Get class counts for all tasks in the dataset.
+     The counts are returned as a dictionary where keys are in the format
+     '{primitive_name}/{task_name}/{class_name}' and values are their respective counts.
+
+     Example:
+         >>> counts = dataset.class_counts_all()
+         >>> print(counts)
+         {
+             objects/bboxes/car: 500
+             objects/bboxes/person: 0
+             classifications/weather/sunny: 300
+             classifications/weather/rainy: 0
+             ...
+         }
+     """
+     class_counts = {}
+     for task in dataset.info.tasks:
+         if task.class_names is None:
+             raise ValueError(f"Task '{task.name}' does not have class names defined.")
+         class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+         for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
+             count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
+             class_counts[count_name] = count
+
+     return class_counts
+
+
+ def print_stats(dataset: HafniaDataset) -> None:
+     """
+     Prints verbose statistics about the dataset, including dataset name, version,
+     number of samples, and detailed counts of samples and tasks.
+     """
+     table_base = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+     table_base.add_column("Property", style="cyan")
+     table_base.add_column("Value")
+     table_base.add_row("Dataset Name", dataset.info.dataset_name)
+     table_base.add_row("Version", dataset.info.version)
+     table_base.add_row("Number of samples", str(len(dataset.samples)))
+     rprint(table_base)
+     print_sample_and_task_counts(dataset)
+     print_class_distribution(dataset)
+
+
+ def print_class_distribution(dataset: HafniaDataset) -> None:
+     """
+     Prints the class distribution for each task in the dataset.
+     """
+     for task in dataset.info.tasks:
+         if task.class_names is None:
+             raise ValueError(f"Task '{task.name}' does not have class names defined.")
+         class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+
+         # Print class distribution
+         rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+         rich_table.add_column("Class Name", style="cyan")
+         rich_table.add_column("Class Idx", style="cyan")
+         rich_table.add_column("Count", justify="right")
+         for class_name, count in class_counts.items():
+             class_idx = task.class_names.index(class_name)  # Get class idx from task info
+             rich_table.add_row(class_name, str(class_idx), str(count))
+         rprint(rich_table)
+
+
+ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
+     """
+     Prints a table with sample counts and task counts for each primitive type
+     in total and for each split (train, val, test).
+     """
+     from hafnia.dataset.operations.table_transformations import create_primitive_table
+     from hafnia.dataset.primitives import PRIMITIVE_TYPES
+
+     splits_sets = {
+         "All": SplitName.valid_splits(),
+         "Train": [SplitName.TRAIN],
+         "Validation": [SplitName.VAL],
+         "Test": [SplitName.TEST],
+     }
+     rows = []
+     for split_name, splits in splits_sets.items():
+         dataset_split = dataset.create_split_dataset(splits)
+         table = dataset_split.samples
+         row = {}
+         row["Split"] = split_name
+         row["Sample "] = str(len(table))
+         for PrimitiveType in PRIMITIVE_TYPES:
+             column_name = PrimitiveType.column_name()
+             objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
+             if objects_df is None:
+                 continue
+             for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+                 count = len(object_group[FieldName.CLASS_NAME])
+                 row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
+         rows.append(row)
+
+     rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
+     for i_row, row in enumerate(rows):
+         if i_row == 0:
+             for column_name in row.keys():
+                 rich_table.add_column(column_name, justify="left", style="cyan")
+         rich_table.add_row(*[str(value) for value in row.values()])
+     rprint(rich_table)
+
+
+ def check_dataset(dataset: HafniaDataset):
+     """
+     Performs various checks on the dataset to ensure its integrity and consistency.
+     Raises errors if any issues are found.
+     """
+     from hafnia.dataset.hafnia_dataset import Sample
+
+     user_logger.info("Checking Hafnia dataset...")
+     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
+
+     sample_dataset = dataset.create_sample_dataset()
+     if len(sample_dataset) == 0:
+         raise ValueError("The dataset does not include a sample dataset")
+
+     actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
+     expected_splits = SplitName.valid_splits()
+     if set(actual_splits) != set(expected_splits):
+         raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+
+     dataset.check_dataset_tasks()
+
+     expected_tasks = dataset.info.tasks
+     distribution = dataset.info.distributions or []
+     distribution_names = [task.name for task in distribution]
+     # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+     for PrimitiveType in PRIMITIVE_TYPES:
+         column_name = PrimitiveType.column_name()
+         if column_name not in dataset.samples.columns:
+             continue
+         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
+         if objects_df is None:
+             continue
+         for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
+             if has_task or (task_name in distribution_names):
+                 continue
+             class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+             raise ValueError(
+                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
+                 f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                 f"classes: {class_names}. "
+             )
+
+     for sample_dict in track(dataset, description="Checking samples in dataset"):
+         sample = Sample(**sample_dict)  # noqa: F841
+
+
+ def check_dataset_tasks(dataset: HafniaDataset):
+     """
+     Checks that the tasks defined in 'dataset.info.tasks' are consistent with the data in 'dataset.samples'.
+     """
+     dataset.info.check_for_duplicate_task_names()
+
+     for task in dataset.info.tasks:
+         primitive = task.primitive.__name__
+         column_name = task.primitive.column_name()
+         primitive_column = dataset.samples[column_name]
+         msg_something_wrong = (
+             f"Something is wrong with the defined tasks ('info.tasks') in dataset '{dataset.info.dataset_name}'. \n"
+             f"For '{primitive=}' and '{task.name=}' "
+         )
+         if primitive_column.dtype == pl.Null:
+             raise ValueError(msg_something_wrong + "the column is 'Null'. Please check the dataset.")
+
+         if len(dataset) > 0:  # Check only performed for non-empty datasets
+             primitive_table = (
+                 primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+             )
+             if primitive_table.is_empty():
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
+                 )
+
+             actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+             if task.class_names is None:
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' with {task.name=} has no defined classes. Please check the dataset."
+                 )
+             defined_classes = set(task.class_names)
+
+             if not actual_classes.issubset(defined_classes):
+                 raise ValueError(
+                     msg_something_wrong
+                     + f"the column '{column_name}' with {task.name=} we expected the actual classes in the dataset to \n"
+                     f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
+                 )
+             # Check class_indices
+             mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+                 lambda x: task.class_names.index(x), return_dtype=pl.Int64
+             )
+             table_indices = primitive_table[FieldName.CLASS_IDX]
+
+             error_msg = msg_something_wrong + (
+                 f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+             )
+             assert mapped_indices.equals(table_indices), error_msg
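
Taken together, the new dataset_stats helpers provide per-task class histograms, rich-formatted summary tables, and integrity checks. A minimal sketch of how they might be exercised; the loader call is a hypothetical stand-in, and the Bbox import path is assumed from the file list above (the diff itself invokes these as HafniaDataset methods, e.g. dataset.class_counts_for_task(...)):

    from hafnia.dataset.hafnia_dataset import HafniaDataset
    from hafnia.dataset.primitives.bbox import Bbox  # assumed path, per hafnia/dataset/primitives/bbox.py above

    dataset = HafniaDataset.from_path("path/to/dataset")  # hypothetical loader, for illustration only

    # Per-task class histogram; zero-count classes are kept so dict order matches class idx.
    bbox_counts = dataset.class_counts_for_task(primitive=Bbox)

    # Flattened '{primitive}/{task}/{class}' counts across all tasks.
    all_counts = dataset.class_counts_all()

    # Rich tables: dataset overview, per-split sample/task counts, class distributions.
    dataset.print_stats()

    # Integrity checks: splits present, tasks consistent with samples, class idx ordering.
    dataset.check_dataset()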