hafnia 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +2 -1
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +60 -4
  11. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  12. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  13. hafnia/dataset/hafnia_dataset.py +176 -50
  14. hafnia/dataset/operations/dataset_stats.py +2 -3
  15. hafnia/dataset/operations/dataset_transformations.py +19 -15
  16. hafnia/dataset/operations/table_transformations.py +4 -3
  17. hafnia/dataset/primitives/bbox.py +25 -12
  18. hafnia/dataset/primitives/bitmask.py +26 -14
  19. hafnia/dataset/primitives/classification.py +16 -8
  20. hafnia/dataset/primitives/point.py +7 -3
  21. hafnia/dataset/primitives/polygon.py +16 -9
  22. hafnia/dataset/primitives/segmentation.py +10 -7
  23. hafnia/experiment/hafnia_logger.py +0 -9
  24. hafnia/platform/dataset_recipe.py +7 -2
  25. hafnia/platform/datasets.py +3 -3
  26. hafnia/platform/download.py +23 -18
  27. hafnia/utils.py +17 -0
  28. hafnia/visualizations/image_visualizations.py +1 -1
  29. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/METADATA +8 -6
  30. hafnia-0.4.0.dist-info/RECORD +56 -0
  31. hafnia-0.3.0.dist-info/RECORD +0 -53
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  33. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  34. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/format_conversions/torchvision_datasets.py (new file)
@@ -0,0 +1,281 @@
+ import inspect
+ import os
+ import shutil
+ import tempfile
+ import textwrap
+ from pathlib import Path
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ from rich.progress import track
+ from torchvision import datasets as tv_datasets
+ from torchvision.datasets import VisionDataset
+ from torchvision.datasets.utils import download_and_extract_archive, extract_archive
+
+ from hafnia import utils
+ from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name
+ from hafnia.dataset.dataset_names import SplitName
+ from hafnia.dataset.format_conversions.image_classification_from_directory import (
+     import_image_classification_directory_tree,
+ )
+ from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+ from hafnia.dataset.primitives import Classification
+
+
+ def torchvision_to_hafnia_converters() -> Dict[str, Callable]:
+     return {
+         "mnist": mnist_as_hafnia_dataset,
+         "cifar10": cifar10_as_hafnia_dataset,
+         "cifar100": cifar100_as_hafnia_dataset,
+         "caltech-101": caltech_101_as_hafnia_dataset,
+         "caltech-256": caltech_256_as_hafnia_dataset,
+     }
+
+
+ def mnist_as_hafnia_dataset(force_redownload=False, n_samples: Optional[int] = None) -> HafniaDataset:
+     samples, tasks = torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+         dataset_loader=tv_datasets.MNIST,
+         force_redownload=force_redownload,
+         n_samples=n_samples,
+     )
+
+     dataset_info = DatasetInfo(
+         dataset_name="mnist",
+         version="1.1.0",
+         tasks=tasks,
+         reference_bibtex=textwrap.dedent("""\
+             @article{lecun2010mnist,
+               title={MNIST handwritten digit database},
+               author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
+               journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
+               volume={2},
+               year={2010}
+             }"""),
+         reference_paper_url=None,
+         reference_dataset_page="http://yann.lecun.com/exdb/mnist",
+     )
+     return HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+
+
+ def cifar10_as_hafnia_dataset(force_redownload: bool = False, n_samples: Optional[int] = None) -> HafniaDataset:
+     return cifar_as_hafnia_dataset(dataset_name="cifar10", force_redownload=force_redownload, n_samples=n_samples)
+
+
+ def cifar100_as_hafnia_dataset(force_redownload: bool = False, n_samples: Optional[int] = None) -> HafniaDataset:
+     return cifar_as_hafnia_dataset(dataset_name="cifar100", force_redownload=force_redownload, n_samples=n_samples)
+
+
+ def caltech_101_as_hafnia_dataset(
+     force_redownload: bool = False,
+     n_samples: Optional[int] = None,
+ ) -> HafniaDataset:
+     dataset_name = "caltech-101"
+     path_image_classification_folder = _download_and_extract_caltech_dataset(
+         dataset_name, force_redownload=force_redownload
+     )
+     hafnia_dataset = import_image_classification_directory_tree(
+         path_image_classification_folder,
+         split=SplitName.TRAIN,
+         n_samples=n_samples,
+     )
+     hafnia_dataset.info.dataset_name = dataset_name
+     hafnia_dataset.info.version = "1.1.0"
+     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
+         @article{FeiFei2004LearningGV,
+           title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian
+                  Approach Tested on 101 Object Categories},
+           author={Li Fei-Fei and Rob Fergus and Pietro Perona},
+           journal={Computer Vision and Pattern Recognition Workshop},
+           year={2004},
+         }
+         """)
+     hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/mzrjq-6wc02"
+
+     return hafnia_dataset
+
+
+ def caltech_256_as_hafnia_dataset(
+     force_redownload: bool = False,
+     n_samples: Optional[int] = None,
+ ) -> HafniaDataset:
+     dataset_name = "caltech-256"
+
+     path_image_classification_folder = _download_and_extract_caltech_dataset(
+         dataset_name, force_redownload=force_redownload
+     )
+     hafnia_dataset = import_image_classification_directory_tree(
+         path_image_classification_folder,
+         split=SplitName.TRAIN,
+         n_samples=n_samples,
+     )
+     hafnia_dataset.info.dataset_name = dataset_name
+     hafnia_dataset.info.version = "1.1.0"
+     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
+         @misc{griffin_2023_5sv1j-ytw97,
+           author    = {Griffin, Gregory and
+                        Holub, Alex and
+                        Perona, Pietro},
+           title     = {Caltech-256 Object Category Dataset},
+           month     = aug,
+           year      = 2023,
+           publisher = {California Institute of Technology},
+           version   = {public},
+         }""")
+     hafnia_dataset.info.reference_dataset_page = "https://data.caltech.edu/records/nyy15-4j048"
+
+     return hafnia_dataset
+
+
+ def cifar_as_hafnia_dataset(
+     dataset_name: str,
+     force_redownload: bool = False,
+     n_samples: Optional[int] = None,
+ ) -> HafniaDataset:
+     if dataset_name == "cifar10":
+         dataset_loader = tv_datasets.CIFAR10
+     elif dataset_name == "cifar100":
+         dataset_loader = tv_datasets.CIFAR100
+     else:
+         raise ValueError(f"Unknown dataset name: {dataset_name}. Supported: cifar10, cifar100")
+     samples, tasks = torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+         dataset_loader=dataset_loader,
+         force_redownload=force_redownload,
+         n_samples=n_samples,
+     )
+
+     dataset_info = DatasetInfo(
+         dataset_name=dataset_name,
+         version="1.1.0",
+         tasks=tasks,
+         reference_bibtex=textwrap.dedent("""\
+             @TECHREPORT{Krizhevsky09learningmultiple,
+               author = {Alex Krizhevsky},
+               title = {Learning multiple layers of features from tiny images},
+               institution = {},
+               year = {2009}
+             }"""),
+         reference_paper_url="https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf",
+         reference_dataset_page="https://www.cs.toronto.edu/~kriz/cifar.html",
+     )
+
+     return HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+
+
+ def torchvision_basic_image_classification_dataset_as_hafnia_dataset(
+     dataset_loader: VisionDataset,
+     force_redownload: bool = False,
+     n_samples: Optional[int] = None,
+ ) -> Tuple[List[Sample], List[TaskInfo]]:
+     """
+     Converts a certain group of torchvision-based image classification datasets to a Hafnia Dataset.
+
+     This conversion only works for a certain group of image classification VisionDatasets provided by torchvision.
+     Common to these datasets is:
+     1) They provide a 'class_to_idx' mapping,
+     2) A "train" boolean parameter in the init function separates training and test data - thus no validation split
+        is available for these datasets,
+     3) Datasets are in-memory and not on disk,
+     4) Samples consist of a PIL image and a class index.
+     """
+     torchvision_dataset_name = dataset_loader.__name__
+
+     # Check if loader has a 'train' parameter using the inspect module
+     params = inspect.signature(dataset_loader).parameters
+
+     has_train_param = ("train" in params) and (params["train"].annotation is bool)
+     if not has_train_param:
+         raise ValueError(
+             f"The dataset loader '{dataset_loader.__name__}' does not have a 'train: bool' parameter in the init "
+             "function. This is a sign that the wrong dataset loader is being used. This conversion function only "
+             "works for certain image classification datasets provided by torchvision that are similar to e.g. "
+             "MNIST, CIFAR-10, CIFAR-100"
+         )
+
+     path_torchvision_dataset = utils.get_path_torchvision_downloads() / torchvision_dataset_name
+     path_hafnia_conversions = utils.get_path_hafnia_conversions() / torchvision_dataset_name
+
+     if force_redownload:
+         shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+         shutil.rmtree(path_hafnia_conversions, ignore_errors=True)
+
+     splits = {
+         SplitName.TRAIN: dataset_loader(root=path_torchvision_dataset, train=True, download=True),
+         SplitName.TEST: dataset_loader(root=path_torchvision_dataset, train=False, download=True),
+     }
+
+     samples = []
+     n_samples_per_split = n_samples // len(splits) if n_samples is not None else None
+     for split_name, torchvision_dataset in splits.items():
+         class_name_to_index = torchvision_dataset.class_to_idx
+         class_index_to_name = {v: k for k, v in class_name_to_index.items()}
+         description = f"Convert '{torchvision_dataset_name}' ({split_name} split) to Hafnia Dataset "
+         samples_in_split = []
+         for image, class_idx in track(torchvision_dataset, total=n_samples_per_split, description=description):
+             (width, height) = image.size
+             path_image = save_pil_image_with_hash_name(image, path_hafnia_conversions)
+             sample = Sample(
+                 file_path=str(path_image),
+                 height=height,
+                 width=width,
+                 split=split_name,
+                 classifications=[
+                     Classification(
+                         class_name=class_index_to_name[class_idx],
+                         class_idx=class_idx,
+                     )
+                 ],
+             )
+             samples_in_split.append(sample)
+
+             if n_samples_per_split is not None and len(samples_in_split) >= n_samples_per_split:
+                 break
+
+         samples.extend(samples_in_split)
+     class_names = list(class_name_to_index.keys())
+     tasks = [TaskInfo(primitive=Classification, class_names=class_names)]
+
+     return samples, tasks
+
+
+ def _download_and_extract_caltech_dataset(dataset_name: str, force_redownload: bool) -> Path:
+     path_torchvision_dataset = utils.get_path_torchvision_downloads() / dataset_name
+
+     if force_redownload:
+         shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+
+     if path_torchvision_dataset.exists():
+         return path_torchvision_dataset
+
+     with tempfile.TemporaryDirectory() as tmpdirname:
+         path_tmp_output = Path(tmpdirname)
+         path_tmp_output.mkdir(parents=True, exist_ok=True)
+
+         if dataset_name == "caltech-101":
+             download_and_extract_archive(
+                 "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
+                 download_root=path_tmp_output,
+                 filename="caltech-101.zip",
+                 md5="3138e1922a9193bfa496528edbbc45d0",
+             )
+             path_output_extracted = path_tmp_output / "caltech-101"
+             for gzip_file in os.listdir(path_output_extracted):
+                 if gzip_file.endswith(".gz"):
+                     extract_archive(os.path.join(path_output_extracted, gzip_file), path_output_extracted)
+             path_org = path_output_extracted / "101_ObjectCategories"
+
+         elif dataset_name == "caltech-256":
+             org_dataset_name = "256_ObjectCategories"
+             path_org = path_tmp_output / org_dataset_name
+             download_and_extract_archive(
+                 url=f"https://data.caltech.edu/records/nyy15-4j048/files/{org_dataset_name}.tar",
+                 download_root=path_tmp_output,
+                 md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+                 remove_finished=True,
+             )
+
+         else:
+             raise ValueError(f"Unknown dataset name: {dataset_name}. Supported: caltech-101, caltech-256")
+
+         shutil.rmtree(path_torchvision_dataset, ignore_errors=True)
+         shutil.move(path_org, path_torchvision_dataset)
+     return path_torchvision_dataset
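
Note: the converters added above are registered by name in torchvision_to_hafnia_converters(). A minimal usage sketch, assuming torchvision is installed alongside hafnia (cache paths and the n_samples cap behave as defined in the module above; the printed values are illustrative only):

    from hafnia.dataset.format_conversions.torchvision_datasets import (
        cifar10_as_hafnia_dataset,
        torchvision_to_hafnia_converters,
    )

    # Registry of supported public datasets -> converter callables
    print(list(torchvision_to_hafnia_converters().keys()))
    # ['mnist', 'cifar10', 'cifar100', 'caltech-101', 'caltech-256']

    # Convert a capped CIFAR-10 subset; n_samples is divided evenly across the train/test splits
    dataset = cifar10_as_hafnia_dataset(force_redownload=False, n_samples=200)
    print(dataset.info.dataset_name, dataset.info.version)  # cifar10 1.1.0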
hafnia/dataset/hafnia_dataset.py
@@ -8,14 +8,15 @@ from dataclasses import dataclass
  from datetime import datetime
  from pathlib import Path
  from random import Random
- from typing import Any, Dict, List, Optional, Type, Union
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union

  import more_itertools
  import numpy as np
  import polars as pl
+ from packaging.version import Version
  from PIL import Image
  from pydantic import BaseModel, Field, field_serializer, field_validator
- from tqdm import tqdm
+ from rich.progress import track

  import hafnia
  from hafnia.dataset import dataset_helpers
@@ -29,10 +30,14 @@ from hafnia.dataset.dataset_names import (
      ColumnName,
      SplitName,
  )
- from hafnia.dataset.operations import dataset_stats, dataset_transformations, table_transformations
+ from hafnia.dataset.operations import (
+     dataset_stats,
+     dataset_transformations,
+     table_transformations,
+ )
  from hafnia.dataset.operations.table_transformations import (
      check_image_paths,
-     read_table_from_path,
+     read_samples_from_path,
  )
  from hafnia.dataset.primitives import PRIMITIVE_TYPES, get_primitive_type_from_string
  from hafnia.dataset.primitives.bbox import Bbox
@@ -44,9 +49,17 @@ from hafnia.log import user_logger


  class TaskInfo(BaseModel):
-     primitive: Type[Primitive]  # Primitive class or string name of the primitive, e.g. "Bbox" or "bitmask"
-     class_names: Optional[List[str]]  # Class names for the tasks. To get consistent class indices specify class_names.
-     name: Optional[str] = None  # Use 'None' to use default name Bbox -> "bboxes", Bitmask -> "bitmasks" etc.
+     primitive: Type[Primitive] = Field(
+         description="Primitive class or string name of the primitive, e.g. 'Bbox' or 'bitmask'"
+     )
+     class_names: Optional[List[str]] = Field(default=None, description="Optional list of class names for the primitive")
+     name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Optional name for the task. 'None' will use default name of the provided primitive. "
+             "e.g. Bbox -> 'bboxes', Bitmask -> 'bitmasks' etc."
+         ),
+     )

      def model_post_init(self, __context: Any) -> None:
          if self.name is None:
@@ -99,17 +112,37 @@ class TaskInfo(BaseModel):


  class DatasetInfo(BaseModel):
-     dataset_name: str
-     version: str  # Dataset version. This is not the same as the Hafnia dataset format version.
-     tasks: List[TaskInfo]
-     distributions: Optional[List[TaskInfo]] = None  # Distributions. TODO: FIX/REMOVE/CHANGE this
-     meta: Optional[Dict[str, Any]] = None  # Metadata about the dataset, e.g. description, etc.
-     format_version: str = hafnia.__dataset_format_version__  # Version of the Hafnia dataset format
-     updated_at: datetime = datetime.now()
+     dataset_name: str = Field(description="Name of the dataset, e.g. 'coco'")
+     version: Optional[str] = Field(default=None, description="Version of the dataset")
+     tasks: List[TaskInfo] = Field(default=None, description="List of tasks in the dataset")
+     distributions: Optional[List[TaskInfo]] = Field(default=None, description="Optional list of task distributions")
+     reference_bibtex: Optional[str] = Field(
+         default=None,
+         description="Optional, BibTeX reference to dataset publication",
+     )
+     reference_paper_url: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to dataset publication",
+     )
+     reference_dataset_page: Optional[str] = Field(
+         default=None,
+         description="Optional, URL to the dataset page",
+     )
+     meta: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the dataset")
+     format_version: str = Field(
+         default=hafnia.__dataset_format_version__,
+         description="Version of the Hafnia dataset format. You should not set this manually.",
+     )
+     updated_at: datetime = Field(
+         default_factory=datetime.now,
+         description="Timestamp of the last update to the dataset info. You should not set this manually.",
+     )

      @field_validator("tasks", mode="after")
      @classmethod
-     def _validate_check_for_duplicate_tasks(cls, tasks: List[TaskInfo]) -> List[TaskInfo]:
+     def _validate_check_for_duplicate_tasks(cls, tasks: Optional[List[TaskInfo]]) -> List[TaskInfo]:
+         if tasks is None:
+             return []
          task_name_counts = collections.Counter(task.name for task in tasks)
          duplicate_task_names = [name for name, count in task_name_counts.items() if count > 1]
          if duplicate_task_names:
@@ -118,6 +151,35 @@ class DatasetInfo(BaseModel):
              )
          return tasks

+     @field_validator("format_version")
+     @classmethod
+     def _validate_format_version(cls, format_version: str) -> str:
+         try:
+             Version(format_version)
+         except Exception as e:
+             raise ValueError(f"Invalid format_version '{format_version}'. Must be a valid version string.") from e
+
+         if Version(format_version) > Version(hafnia.__dataset_format_version__):
+             user_logger.warning(
+                 f"The loaded dataset format version '{format_version}' is newer than the format version "
+                 f"'{hafnia.__dataset_format_version__}' used in your version of Hafnia. Please consider "
+                 f"updating Hafnia package."
+             )
+         return format_version
+
+     @field_validator("version")
+     @classmethod
+     def _validate_version(cls, dataset_version: Optional[str]) -> Optional[str]:
+         if dataset_version is None:
+             return None
+
+         try:
+             Version(dataset_version)
+         except Exception as e:
+             raise ValueError(f"Invalid dataset_version '{dataset_version}'. Must be a valid version string.") from e
+
+         return dataset_version
+
      def check_for_duplicate_task_names(self) -> List[TaskInfo]:
          return self._validate_check_for_duplicate_tasks(self.tasks)

@@ -187,7 +249,7 @@ class DatasetInfo(BaseModel):
          meta.update(info1.meta or {})
          return DatasetInfo(
              dataset_name=info0.dataset_name + "+" + info1.dataset_name,
-             version="merged",
+             version=None,
              tasks=list(unique_tasks),
              distributions=list(distributions),
              meta=meta,
@@ -258,22 +320,40 @@ class DatasetInfo(BaseModel):


  class Sample(BaseModel):
-     file_name: str
-     height: int
-     width: int
-     split: str  # Split name, e.g., "train", "val", "test"
-     tags: List[str] = []  # tags for a given sample. Used for creating subsets of the dataset.
-     collection_index: Optional[int] = None  # Optional e.g. frame number for video datasets
-     collection_id: Optional[str] = None  # Optional e.g. video name for video datasets
-     remote_path: Optional[str] = None  # Optional remote path for the image, if applicable
-     sample_index: Optional[int] = None  # Don't manually set this, it is used for indexing samples in the dataset.
-     classifications: Optional[List[Classification]] = None  # Optional classification primitive
-     objects: Optional[List[Bbox]] = None  # List of coordinate primitives, e.g., Bbox, Bitmask, etc.
-     bitmasks: Optional[List[Bitmask]] = None  # List of bitmasks, if applicable
-     polygons: Optional[List[Polygon]] = None  # List of polygons, if applicable
-
-     attribution: Optional[Attribution] = None  # Attribution information for the image
-     meta: Optional[Dict] = None  # Additional metadata, e.g., camera settings, GPS data, etc.
+     file_path: str = Field(description="Path to the image file")
+     height: int = Field(description="Height of the image")
+     width: int = Field(description="Width of the image")
+     split: str = Field(description="Split name, e.g., 'train', 'val', 'test'")
+     tags: List[str] = Field(
+         default_factory=list,
+         description="Tags for a given sample. Used for creating subsets of the dataset.",
+     )
+     collection_index: Optional[int] = Field(default=None, description="Optional e.g. frame number for video datasets")
+     collection_id: Optional[str] = Field(default=None, description="Optional e.g. video name for video datasets")
+     remote_path: Optional[str] = Field(default=None, description="Optional remote path for the image, if applicable")
+     sample_index: Optional[int] = Field(
+         default=None,
+         description="Don't manually set this, it is used for indexing samples in the dataset.",
+     )
+     classifications: Optional[List[Classification]] = Field(
+         default=None, description="Optional list of classifications"
+     )
+     objects: Optional[List[Bbox]] = Field(default=None, description="Optional list of objects (bounding boxes)")
+     bitmasks: Optional[List[Bitmask]] = Field(default=None, description="Optional list of bitmasks")
+     polygons: Optional[List[Polygon]] = Field(default=None, description="Optional list of polygons")
+
+     attribution: Optional[Attribution] = Field(default=None, description="Attribution information for the image")
+     dataset_name: Optional[str] = Field(
+         default=None,
+         description=(
+             "Don't manually set this, it will be automatically defined during initialization. "
+             "Name of the dataset the sample belongs to. E.g. 'coco-2017' or 'midwest-vehicle-detection'."
+         ),
+     )
+     meta: Optional[Dict] = Field(
+         default=None,
+         description="Additional metadata, e.g., camera settings, GPS data, etc.",
+     )

      def get_annotations(self, primitive_types: Optional[List[Type[Primitive]]] = None) -> List[Primitive]:
          """
@@ -294,7 +374,7 @@ class Sample(BaseModel):
          Reads the image from the file path and returns it as a PIL Image.
          Raises FileNotFoundError if the image file does not exist.
          """
-         path_image = Path(self.file_name)
+         path_image = Path(self.file_path)
          if not path_image.exists():
              raise FileNotFoundError(f"Image file {path_image} does not exist. Please check the file path.")

@@ -413,30 +493,23 @@ class HafniaDataset:
              yield row

      def __post_init__(self):
-         samples = self.samples
-         if ColumnName.SAMPLE_INDEX not in samples.columns:
-             samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
-
-         # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-         if ColumnName.TAGS not in samples.columns:
-             tags_column: List[List[str]] = [[] for _ in range(len(self))]  # type: ignore[annotation-unchecked]
-             samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
-
-         self.samples = samples
+         self.samples, self.info = _dataset_corrections(self.samples, self.info)

      @staticmethod
      def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
+         path_folder = Path(path_folder)
          HafniaDataset.check_dataset_path(path_folder, raise_error=True)

          dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-         table = read_table_from_path(path_folder)
+         samples = read_samples_from_path(path_folder)
+         samples, dataset_info = _dataset_corrections(samples, dataset_info)

          # Convert from relative paths to absolute paths
          dataset_root = path_folder.absolute().as_posix() + "/"
-         table = table.with_columns((dataset_root + pl.col("file_name")).alias("file_name"))
+         samples = samples.with_columns((dataset_root + pl.col(ColumnName.FILE_PATH)).alias(ColumnName.FILE_PATH))
          if check_for_images:
-             check_image_paths(table)
-         return HafniaDataset(samples=table, info=dataset_info)
+             check_image_paths(samples)
+         return HafniaDataset(samples=samples, info=dataset_info)

      @staticmethod
      def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
@@ -464,6 +537,14 @@ class HafniaDataset:

          table = pl.from_records(json_samples)
          table = table.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+         # Add 'dataset_name' to samples
+         table = table.with_columns(
+             pl.when(pl.col(ColumnName.DATASET_NAME).is_null())
+             .then(pl.lit(info.dataset_name))
+             .otherwise(pl.col(ColumnName.DATASET_NAME))
+             .alias(ColumnName.DATASET_NAME)
+         )
          return HafniaDataset(info=info, samples=table)

      @staticmethod
@@ -518,6 +599,28 @@ class HafniaDataset:
              merged_dataset = HafniaDataset.merge(merged_dataset, dataset)
          return merged_dataset

+     @staticmethod
+     def from_name_public_dataset(
+         name: str,
+         force_redownload: bool = False,
+         n_samples: Optional[int] = None,
+     ) -> HafniaDataset:
+         from hafnia.dataset.format_conversions.torchvision_datasets import (
+             torchvision_to_hafnia_converters,
+         )
+
+         name_to_torchvision_function = torchvision_to_hafnia_converters()
+
+         if name not in name_to_torchvision_function:
+             raise ValueError(
+                 f"Unknown torchvision dataset name: {name}. Supported: {list(name_to_torchvision_function.keys())}"
+             )
+         vision_dataset = name_to_torchvision_function[name]
+         return vision_dataset(
+             force_redownload=force_redownload,
+             n_samples=n_samples,
+         )
+
      def shuffle(dataset: HafniaDataset, seed: int = 42) -> HafniaDataset:
          table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
          return dataset.update_samples(table)
@@ -607,7 +710,7 @@ class HafniaDataset:

      def class_mapper(
          dataset: "HafniaDataset",
-         class_mapping: Dict[str, str],
+         class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
          method: str = "strict",
          primitive: Optional[Type[Primitive]] = None,
          task_name: Optional[str] = None,
@@ -778,13 +881,14 @@ class HafniaDataset:
          path_folder.mkdir(parents=True)

          new_relative_paths = []
-         for org_path in tqdm(self.samples["file_name"].to_list(), desc="- Copy images"):
+         org_paths = self.samples[ColumnName.FILE_PATH].to_list()
+         for org_path in track(org_paths, description="- Copy images"):
              new_path = dataset_helpers.copy_and_rename_file_to_hash_value(
                  path_source=Path(org_path),
                  path_dataset_root=path_folder,
              )
              new_relative_paths.append(str(new_path.relative_to(path_folder)))
-         table = self.samples.with_columns(pl.Series(new_relative_paths).alias("file_name"))
+         table = self.samples.with_columns(pl.Series(new_relative_paths).alias(ColumnName.FILE_PATH))

          if drop_null_cols:  # Drops all unused/Null columns
              table = table.drop(pl.selectors.by_dtype(pl.Null))
@@ -846,3 +950,25 @@ def get_or_create_dataset_path_from_recipe(
      dataset.write(path_dataset)

      return path_dataset
+
+
+ def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+     format_version_of_dataset = Version(dataset_info.format_version)
+
+     ## Backwards compatibility fixes for older dataset versions
+     if format_version_of_dataset <= Version("0.3.0"):
+         if ColumnName.DATASET_NAME not in samples.columns:
+             samples = samples.with_columns(pl.lit(dataset_info.dataset_name).alias(ColumnName.DATASET_NAME))
+
+         if "file_name" in samples.columns:
+             samples = samples.rename({"file_name": ColumnName.FILE_PATH})
+
+     if ColumnName.SAMPLE_INDEX not in samples.columns:
+         samples = samples.with_row_index(name=ColumnName.SAMPLE_INDEX)
+
+     # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+     if ColumnName.TAGS not in samples.columns:
+         tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+         samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(ColumnName.TAGS))
+
+     return samples, dataset_info
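
With version and the new reference_* fields now optional on DatasetInfo, and Sample.file_name renamed to Sample.file_path, a 0.4.0 dataset could be assembled as in the following sketch (class, field, and function names are taken from the diff above; the concrete values are invented for illustration):

    from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
    from hafnia.dataset.primitives import Classification

    info = DatasetInfo(
        dataset_name="my-classification-set",
        version="0.1.0",  # optional since 0.4.0; validated as a version string when set
        tasks=[TaskInfo(primitive=Classification, class_names=["cat", "dog"])],
        reference_dataset_page="https://example.com/dataset",  # new optional reference field
    )
    samples = [
        Sample(
            file_path="images/0001.jpg",  # was 'file_name' in the 0.3.0 schema
            height=480,
            width=640,
            split="train",
            classifications=[Classification(class_name="cat", class_idx=0)],
        )
    ]
    dataset = HafniaDataset.from_samples_list(samples_list=samples, info=info)

Datasets written with format version 0.3.0 or older remain loadable: _dataset_corrections() above renames the old file_name column, backfills dataset_name, and adds sample_index and tags columns when missing.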
hafnia/dataset/operations/dataset_stats.py
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Type
  import polars as pl
  import rich
  from rich import print as rprint
+ from rich.progress import track
  from rich.table import Table
- from tqdm import tqdm

  from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
  from hafnia.dataset.operations.table_transformations import create_primitive_table
@@ -179,7 +179,6 @@ def check_dataset(dataset: HafniaDataset):
      from hafnia.dataset.hafnia_dataset import Sample

      user_logger.info("Checking Hafnia dataset...")
-     assert isinstance(dataset.info.version, str) and len(dataset.info.version) > 0
      assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0

      sample_dataset = dataset.create_sample_dataset()
@@ -215,7 +214,7 @@ def check_dataset(dataset: HafniaDataset):
                  f"classes: {class_names}. "
              )

-     for sample_dict in tqdm(dataset, desc="Checking samples in dataset"):
+     for sample_dict in track(dataset, description="Checking samples in dataset"):
          sample = Sample(**sample_dict)  # noqa: F841

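
The same tqdm-to-rich migration appears here as in hafnia_dataset.py: progress bars now go through rich.progress.track. A rough mapping of the call sites changed in this release (keyword names as used in the diff above):

    # tqdm (0.3.0)                     ->  rich.progress.track (0.4.0)
    # tqdm(items, desc="Copy images")  ->  track(items, description="Copy images")
    # tqdm(iterable, total=n, desc=d)  ->  track(iterable, total=n, description=d)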