hafnia 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +2 -1
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +60 -4
  11. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  12. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  13. hafnia/dataset/hafnia_dataset.py +176 -50
  14. hafnia/dataset/operations/dataset_stats.py +2 -3
  15. hafnia/dataset/operations/dataset_transformations.py +19 -15
  16. hafnia/dataset/operations/table_transformations.py +4 -3
  17. hafnia/dataset/primitives/bbox.py +25 -12
  18. hafnia/dataset/primitives/bitmask.py +26 -14
  19. hafnia/dataset/primitives/classification.py +16 -8
  20. hafnia/dataset/primitives/point.py +7 -3
  21. hafnia/dataset/primitives/polygon.py +16 -9
  22. hafnia/dataset/primitives/segmentation.py +10 -7
  23. hafnia/experiment/hafnia_logger.py +0 -9
  24. hafnia/platform/dataset_recipe.py +7 -2
  25. hafnia/platform/datasets.py +3 -3
  26. hafnia/platform/download.py +23 -18
  27. hafnia/utils.py +17 -0
  28. hafnia/visualizations/image_visualizations.py +1 -1
  29. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/METADATA +8 -6
  30. hafnia-0.4.0.dist-info/RECORD +56 -0
  31. hafnia-0.3.0.dist-info/RECORD +0 -53
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  33. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  34. {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py CHANGED
@@ -37,7 +37,9 @@ def configure(cfg: Config) -> None:
 
     platform_url = click.prompt("Hafnia Platform URL", type=str, default=consts.DEFAULT_API_URL)
 
-    cfg_profile = ConfigSchema(api_key=api_key, platform_url=platform_url)
+    use_keychain = click.confirm("Store API key in system keychain?", default=False)
+
+    cfg_profile = ConfigSchema(platform_url=platform_url, api_key=api_key, use_keychain=use_keychain)
     cfg.add_profile(profile_name, cfg_profile, set_active=True)
     cfg.save_config()
     profile_cmds.profile_show(cfg)
cli/config.py CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 from pydantic import BaseModel, field_validator
 
 import cli.consts as consts
+import cli.keychain as keychain
 from hafnia.log import sys_logger, user_logger
 
 PLATFORM_API_MAPPING = {
@@ -19,9 +20,18 @@ PLATFORM_API_MAPPING = {
 }
 
 
+class SecretStr(str):
+    def __repr__(self):
+        return "********"
+
+    def __str__(self):
+        return "********"
+
+
 class ConfigSchema(BaseModel):
     platform_url: str = ""
     api_key: Optional[str] = None
+    use_keychain: bool = False
 
     @field_validator("api_key")
     def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
@@ -35,7 +45,7 @@ class ConfigSchema(BaseModel):
             sys_logger.warning("API key is missing the 'ApiKey ' prefix. Prefix is being added automatically.")
             value = f"ApiKey {value}"
 
-        return value
+        return SecretStr(value)  # Keeps the API key masked in logs and repr
 
 
 class ConfigFileSchema(BaseModel):
@@ -70,13 +80,32 @@ class Config:
 
     @property
     def api_key(self) -> str:
+        # Check keychain first if enabled
+        if self.config.use_keychain:
+            keychain_key = keychain.get_api_key(self.active_profile)
+            if keychain_key is not None:
+                return keychain_key
+
+        # Fall back to config file
         if self.config.api_key is not None:
             return self.config.api_key
+
         raise ValueError(consts.ERROR_API_KEY_NOT_SET)
 
     @api_key.setter
     def api_key(self, value: str) -> None:
-        self.config.api_key = value
+        # Store in keychain if enabled
+        if self.config.use_keychain:
+            if keychain.store_api_key(self.active_profile, value):
+                # Successfully stored in keychain, don't store in config
+                self.config.api_key = None
+            else:
+                # Keychain storage failed, fall back to config file
+                sys_logger.warning("Failed to store in keychain, falling back to config file")
+                self.config.api_key = value
+        else:
+            # Not using keychain, store in config file
+            self.config.api_key = value
 
     @property
     def platform_url(self) -> str:
@@ -152,8 +181,19 @@ class Config:
         raise ValueError("Failed to parse configuration file")
 
     def save_config(self) -> None:
+        # Create a copy to avoid modifying the original data
+        config_to_save = self.config_data.model_dump()
+
+        # Store API key in keychain if enabled, and don't write to file
+        for profile_name, profile_data in config_to_save["profiles"].items():
+            if profile_data.get("use_keychain", False):
+                api_key = profile_data.get("api_key")
+                if api_key:
+                    keychain.store_api_key(profile_name, api_key)
+                    profile_data["api_key"] = None
+
         with open(self.config_path, "w") as f:
-            json.dump(self.config_data.model_dump(), f, indent=4)
+            json.dump(config_to_save, f, indent=4)
 
     def remove_profile(self, profile_name: str) -> None:
         if profile_name not in self.config_data.profiles:
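A minimal sketch of how the new ConfigSchema fields behave, based only on the hunks above; the URL and key values are placeholders:

    from cli.config import ConfigSchema

    profile = ConfigSchema(platform_url="https://example.invalid", api_key="my-secret", use_keychain=True)

    # The validator adds the missing "ApiKey " prefix and wraps the value in SecretStr,
    # so printing or logging the field only shows a mask.
    print(profile.api_key)                        # ********
    print(profile.api_key == "ApiKey my-secret")  # True - SecretStr subclasses str, so the real value stays usable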
cli/keychain.py ADDED
@@ -0,0 +1,88 @@
+"""Keychain storage for API keys using the system keychain."""
+
+from typing import Optional
+
+from hafnia.log import sys_logger
+
+# Keyring is optional - gracefully degrade if not available
+try:
+    import keyring
+
+    KEYRING_AVAILABLE = True
+except ImportError:
+    KEYRING_AVAILABLE = False
+    sys_logger.debug("keyring library not available, keychain storage disabled")
+
+KEYRING_SERVICE_NAME = "hafnia-cli"
+
+
+def store_api_key(profile_name: str, api_key: str) -> bool:
+    """
+    Store an API key in the system keychain.
+
+    Args:
+        profile_name: The profile name to associate with the key
+        api_key: The API key to store
+
+    Returns:
+        True if successfully stored, False otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        sys_logger.warning("Keyring library not available, cannot store API key in keychain")
+        return False
+
+    try:
+        keyring.set_password(KEYRING_SERVICE_NAME, profile_name, api_key)
+        sys_logger.debug(f"Stored API key for profile '{profile_name}' in keychain")
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to store API key in keychain: {e}")
+        return False
+
+
+def get_api_key(profile_name: str) -> Optional[str]:
+    """
+    Retrieve an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to retrieve the key for
+
+    Returns:
+        The API key if found, None otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        return None
+
+    try:
+        api_key = keyring.get_password(KEYRING_SERVICE_NAME, profile_name)
+        if api_key:
+            sys_logger.debug(f"Retrieved API key for profile '{profile_name}' from keychain")
+        return api_key
+    except Exception as e:
+        sys_logger.warning(f"Failed to retrieve API key from keychain: {e}")
+        return None
+
+
+def delete_api_key(profile_name: str) -> bool:
+    """
+    Delete an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to delete the key for
+
+    Returns:
+        True if successfully deleted or didn't exist, False on error
+    """
+    if not KEYRING_AVAILABLE:
+        return False
+
+    try:
+        keyring.delete_password(KEYRING_SERVICE_NAME, profile_name)
+        sys_logger.debug(f"Deleted API key for profile '{profile_name}' from keychain")
+        return True
+    except keyring.errors.PasswordDeleteError:
+        # Key didn't exist, which is fine
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to delete API key from keychain: {e}")
+        return False
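A short usage sketch of the new module, assuming the optional keyring dependency is installed; the profile name and key are placeholders:

    from cli import keychain

    if keychain.KEYRING_AVAILABLE:
        keychain.store_api_key("default", "ApiKey my-secret")  # True on success
        print(keychain.get_api_key("default"))                 # "ApiKey my-secret"
        keychain.delete_api_key("default")                     # True, also True if the key did not exist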
cli/profile_cmds.py CHANGED
@@ -50,10 +50,13 @@ def cmd_profile_use(cfg: Config, profile_name: str) -> None:
 @click.option(
     "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
 )
+@click.option(
+    "--use-keychain", is_flag=True, help="Store API key in system keychain instead of config file", default=False
+)
 @click.pass_obj
-def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
+def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool, use_keychain: bool) -> None:
     """Create a new profile."""
-    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
+    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key, use_keychain=use_keychain)
 
     cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
     profile_show(cfg)
hafnia/__init__.py CHANGED
@@ -3,4 +3,4 @@ from importlib.metadata import version
 __package_name__ = "hafnia"
 __version__ = version(__package_name__)
 
-__dataset_format_version__ = "0.0.2"  # Hafnia dataset format version
+__dataset_format_version__ = "0.1.0"  # Hafnia dataset format version
hafnia/dataset/dataset_helpers.py CHANGED
@@ -38,12 +38,19 @@ def hash_from_bytes(data: bytes) -> str:
 
 def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
     pil_image = Image.fromarray(image)
+    path_image = save_pil_image_with_hash_name(pil_image, path_folder)
+    return path_image
+
+
+def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_skip: bool = True) -> Path:
     buffer = io.BytesIO()
-    pil_image.save(buffer, format="PNG")
+    image.save(buffer, format="PNG")
     hash_value = hash_from_bytes(buffer.getvalue())
     path_image = Path(path_folder) / relative_path_from_hash(hash=hash_value, suffix=".png")
+    if allow_skip and path_image.exists():
+        return path_image
     path_image.parent.mkdir(parents=True, exist_ok=True)
-    pil_image.save(path_image)
+    image.save(path_image)
     return path_image
 
 
hafnia/dataset/dataset_names.py CHANGED
@@ -49,7 +49,7 @@ class FieldName:
 
 class ColumnName:
     SAMPLE_INDEX: str = "sample_index"
-    FILE_NAME: str = "file_name"
+    FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
@@ -57,6 +57,7 @@ class ColumnName:
     ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
     TAGS: str = "tags"
    META: str = "meta"
+    DATASET_NAME: str = "dataset_name"
 
 
 class SplitName:
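To illustrate the renamed and added column constants, a small polars sketch; the table contents are made up:

    import polars as pl

    from hafnia.dataset.dataset_names import ColumnName

    # FILE_NAME ("file_name") is replaced by FILE_PATH ("file_path"); DATASET_NAME is new.
    samples = pl.DataFrame({ColumnName.FILE_PATH: ["data/ab/cd.png"], ColumnName.DATASET_NAME: ["demo"]})
    print(samples.select(ColumnName.FILE_PATH, ColumnName.DATASET_NAME))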
hafnia/dataset/dataset_recipe/dataset_recipe.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import (
     field_serializer,
@@ -12,7 +12,11 @@ from pydantic import (
 
 from hafnia import utils
 from hafnia.dataset.dataset_recipe import recipe_transforms
-from hafnia.dataset.dataset_recipe.recipe_types import RecipeCreation, RecipeTransform, Serializable
+from hafnia.dataset.dataset_recipe.recipe_types import (
+    RecipeCreation,
+    RecipeTransform,
+    Serializable,
+)
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.dataset.primitives.primitive import Primitive
 
@@ -41,6 +45,17 @@ class DatasetRecipe(Serializable):
         creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
         return DatasetRecipe(creation=creation)
 
+    @staticmethod
+    def from_name_public_dataset(
+        name: str, force_redownload: bool = False, n_samples: Optional[int] = None
+    ) -> DatasetRecipe:
+        creation = FromNamePublicDataset(
+            name=name,
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+        return DatasetRecipe(creation=creation)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
         creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
@@ -222,7 +237,7 @@ class DatasetRecipe(Serializable):
         """Serialize the dataset recipe to a dictionary."""
         return self.model_dump(mode="json")
 
-    def as_platform_recipe(self, recipe_name: Optional[str]) -> Dict:
+    def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
         """Uploads dataset recipe to the hafnia platform."""
         from cli.config import Config
         from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
@@ -235,6 +250,7 @@ class DatasetRecipe(Serializable):
             endpoint=endpoint_dataset,
             api_key=cfg.api_key,
             name=recipe_name,
+            overwrite=overwrite,
         )
 
         return recipe_dict
@@ -246,10 +262,17 @@ class DatasetRecipe(Serializable):
         return recipe
 
     def select_samples(
-        recipe: DatasetRecipe, n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+        recipe: DatasetRecipe,
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> DatasetRecipe:
         operation = recipe_transforms.SelectSamples(
-            n_samples=n_samples, shuffle=shuffle, seed=seed, with_replacement=with_replacement
+            n_samples=n_samples,
+            shuffle=shuffle,
+            seed=seed,
+            with_replacement=with_replacement,
         )
         recipe.append_operation(operation)
         return recipe
@@ -273,7 +296,7 @@ class DatasetRecipe(Serializable):
 
     def class_mapper(
         recipe: DatasetRecipe,
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -400,6 +423,22 @@ class FromName(RecipeCreation):
         return [self.name]
 
 
+class FromNamePublicDataset(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    n_samples: Optional[int] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name_public_dataset
+
+    def as_short_name(self) -> str:
+        return f"Torchvision('{self.name}')"
+
+    def get_dataset_names(self) -> List[str]:
+        return []
+
+
 class FromMerge(RecipeCreation):
     recipe0: DatasetRecipe
     recipe1: DatasetRecipe
@@ -414,7 +453,10 @@ class FromMerge(RecipeCreation):
 
     def get_dataset_names(self) -> List[str]:
         """Get the dataset names from the merged recipes."""
-        names = [*self.recipe0.creation.get_dataset_names(), *self.recipe1.creation.get_dataset_names()]
+        names = [
+            *self.recipe0.creation.get_dataset_names(),
+            *self.recipe1.creation.get_dataset_names(),
+        ]
         return names
 
 
@@ -439,33 +481,3 @@ class FromMerger(RecipeCreation):
         for recipe in self.recipes:
             names.extend(recipe.creation.get_dataset_names())
         return names
-
-
-def extract_dataset_names_from_json_dict(data: dict) -> list[str]:
-    """
-    Extract dataset names recursively from a JSON dictionary added with 'from_name'.
-
-    Even if the same functionality is achieved with `DatasetRecipe.get_dataset_names()`,
-    we want to keep this function in 'dipdatalib' to extract dataset names from json dictionaries
-    directly.
-    """
-    creation_field = data.get("creation")
-    if creation_field is None:
-        return []
-    if creation_field.get("__type__") == "FromName":
-        return [creation_field["name"]]
-    elif creation_field.get("__type__") == "FromMerge":
-        recipe_names = ["recipe0", "recipe1"]
-        dataset_name = []
-        for recipe_name in recipe_names:
-            recipe = creation_field.get(recipe_name)
-            if recipe is None:
-                continue
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    elif creation_field.get("__type__") == "FromMerger":
-        dataset_name = []
-        for recipe in creation_field.get("recipes", []):
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    return []
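A sketch of the new recipe entry point chained with the reformatted select_samples call, based on the hunks above; "mnist" is an assumed public dataset name, substitute one your platform supports:

    from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

    recipe = DatasetRecipe.from_name_public_dataset(name="mnist", n_samples=1_000)
    recipe = recipe.select_samples(n_samples=100, shuffle=True, seed=42)
    print(recipe.as_dict())  # JSON-serializable form, as in the as_dict() hunk above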
hafnia/dataset/dataset_recipe/recipe_transforms.py CHANGED
@@ -1,4 +1,6 @@
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import field_validator
 
 from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
 from hafnia.dataset.hafnia_dataset import HafniaDataset
@@ -52,11 +54,25 @@ class DefineSampleSetBySize(RecipeTransform):
 
 
 class ClassMapper(RecipeTransform):
-    class_mapping: Dict[str, str]
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]]
     method: str = "strict"
     primitive: Optional[Type[Primitive]] = None
     task_name: Optional[str] = None
 
+    @field_validator("class_mapping", mode="after")
+    @classmethod
+    def serialize_class_mapping(cls, value: Union[Dict[str, str], List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
+        # Converts the dictionary class mapping to a list of tuples
+        # e.g. {"old_class": "new_class", } --> [("old_class", "new_class")]
+        # The reason is that storing class mappings as a dictionary does not preserve order of json fields
+        # when stored in a database as a jsonb field (postgres).
+        # Preserving order of class mapping fields is important as it defines the indices of the classes.
+        # So to ensure that class indices are maintained, we preserve order of json fields, by converting the
+        # dictionary to a list of tuples.
+        if isinstance(value, dict):
+            value = list(value.items())
+        return value
+
     @staticmethod
     def get_function() -> Callable[..., "HafniaDataset"]:
         return HafniaDataset.class_mapper
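A minimal sketch of the new validator's effect; the class names are illustrative:

    from hafnia.dataset.dataset_recipe.recipe_transforms import ClassMapper

    mapper = ClassMapper(class_mapping={"person": "human", "rider": "human"})
    print(mapper.class_mapping)  # [('person', 'human'), ('rider', 'human')] - an order-preserving list of tuples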
hafnia/dataset/dataset_upload_helper.py CHANGED
@@ -4,7 +4,7 @@ import base64
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import boto3
 import polars as pl
@@ -52,6 +52,7 @@ class DbDataset(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
     license_citation: Optional[str] = None
     version: Optional[str] = None
     s3_bucket_name: Optional[str] = None
+    dataset_format_version: Optional[str] = None
     annotation_date: Optional[datetime] = None
     annotation_project_id: Optional[str] = None
     annotation_dataset_id: Optional[str] = None
@@ -186,9 +187,58 @@ class EntityTypeChoices(str, Enum):  # Should match `EntityTypeChoices` in `dipd
     EVENT = "EVENT"
 
 
+class Annotations(BaseModel):
+    """
+    Used in 'DatasetImageMetadata' for visualizing image annotations
+    in gallery images on the dataset detail page.
+    """
+
+    objects: Optional[List[Bbox]] = None
+    classifications: Optional[List[Classification]] = None
+    polygons: Optional[List[Polygon]] = None
+    bitmasks: Optional[List[Bitmask]] = None
+
+
+class DatasetImageMetadata(BaseModel):
+    """
+    Metadata for gallery images on the dataset detail page on portal.
+    """
+
+    annotations: Optional[Annotations] = None
+    meta: Optional[Dict[str, Any]] = None
+
+    @classmethod
+    def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
+        sample = sample.model_copy(deep=True)
+        sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
+        metadata = {}
+        metadata_field_names = [
+            ColumnName.FILE_PATH,
+            ColumnName.HEIGHT,
+            ColumnName.WIDTH,
+            ColumnName.SPLIT,
+        ]
+        for field_name in metadata_field_names:
+            if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
+                metadata[field_name] = getattr(sample, field_name)
+
+        obj = DatasetImageMetadata(
+            annotations=Annotations(
+                objects=sample.objects,
+                classifications=sample.classifications,
+                polygons=sample.polygons,
+                bitmasks=sample.bitmasks,
+            ),
+            meta=metadata,
+        )
+
+        return obj
+
+
 class DatasetImage(Attribution, validate_assignment=True):  # type: ignore[call-arg]
     img: str  # Base64-encoded image string
     order: Optional[int] = None
+    metadata: Optional[DatasetImageMetadata] = None
 
     @field_validator("img", mode="before")
     def validate_image_path(cls, v: Union[str, Path]) -> str:
@@ -254,7 +304,7 @@ def upload_dataset_details(cfg: Config, data: str, dataset_name: str) -> dict:
     import_endpoint = f"{dataset_endpoint}/{dataset_id}/import"
     headers = {"Authorization": cfg.api_key}
 
-    user_logger.info("Importing dataset details. This may take up to 30 seconds...")
+    user_logger.info("Exporting dataset details to platform. This may take up to 30 seconds...")
     response = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
     return response  # type: ignore[return-value]
 
@@ -569,7 +619,9 @@ def dataset_info_from_dataset(
         s3_bucket_name=bucket_sample,
         dataset_variants=dataset_variants,
         split_annotations_reports=dataset_reports,
-        license_citation=dataset_meta_info.get("license_citation", None),
+        latest_update=dataset.info.updated_at,
+        dataset_format_version=dataset.info.format_version,
+        license_citation=dataset.info.reference_bibtex,
         data_captured_start=dataset_meta_info.get("data_captured_start", None),
         data_captured_end=dataset_meta_info.get("data_captured_end", None),
         data_received_start=dataset_meta_info.get("data_received_start", None),
@@ -594,7 +646,7 @@ def create_gallery_images(
     path_gallery_images.mkdir(parents=True, exist_ok=True)
     COL_IMAGE_NAME = "image_name"
     samples = dataset.samples.with_columns(
-        dataset.samples[ColumnName.FILE_NAME].str.split("/").list.last().alias(COL_IMAGE_NAME)
+        dataset.samples[ColumnName.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
     )
     gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))
 
@@ -604,6 +656,9 @@ def create_gallery_images(
     gallery_images = []
     for gallery_sample in gallery_samples.iter_rows(named=True):
         sample = Sample(**gallery_sample)
+
+        metadata = DatasetImageMetadata.from_sample(sample=sample)
+        sample.classifications = None  # To not draw classifications in gallery images
         image = sample.draw_annotations()
 
         path_gallery_image = path_gallery_images / gallery_sample[COL_IMAGE_NAME]
@@ -611,6 +666,7 @@
 
         dataset_image_dict = {
             "img": path_gallery_image,
+            "metadata": metadata,
         }
         if sample.attribution is not None:
            sample.attribution.changes = "Annotations have been visualized"
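A small sketch of the new gallery metadata models constructed directly; the field values are placeholders, and in create_gallery_images the metadata is instead built from a Sample via DatasetImageMetadata.from_sample:

    from hafnia.dataset.dataset_upload_helper import Annotations, DatasetImageMetadata
    from hafnia.dataset.primitives import Classification

    metadata = DatasetImageMetadata(
        annotations=Annotations(classifications=[Classification(class_name="cat", class_idx=0)]),
        meta={"file_path": "images/ab/cd.png", "height": 480, "width": 640, "split": "train"},
    )
    print(metadata.model_dump_json())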
hafnia/dataset/format_conversions/image_classification_from_directory.py ADDED
@@ -0,0 +1,106 @@
+import shutil
+from pathlib import Path
+from typing import List, Optional
+
+import more_itertools
+import polars as pl
+from PIL import Image
+from rich.progress import track
+
+from hafnia.dataset.dataset_names import ColumnName, FieldName
+from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+from hafnia.dataset.primitives import Classification
+from hafnia.utils import is_image_file
+
+
+def import_image_classification_directory_tree(
+    path_folder: Path,
+    split: str,
+    n_samples: Optional[int] = None,
+) -> HafniaDataset:
+    class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
+    class_names = sorted([folder.name for folder in class_folder_paths])  # Sort for determinism
+
+    # Gather all image paths per class
+    path_images_per_class: List[List[Path]] = []
+    for path_class_folder in class_folder_paths:
+        per_class_images = []
+        for path_image in list(path_class_folder.rglob("*.*")):
+            if is_image_file(path_image):
+                per_class_images.append(path_image)
+        path_images_per_class.append(sorted(per_class_images))
+
+    # Interleave to ensure classes are balanced in the output dataset for n_samples < total
+    path_images = list(more_itertools.interleave_longest(*path_images_per_class))
+
+    if n_samples is not None:
+        path_images = path_images[:n_samples]
+
+    samples = []
+    for path_image_org in track(path_images, description="Convert 'image classification' dataset to Hafnia Dataset"):
+        class_name = path_image_org.parent.name
+
+        read_image = Image.open(path_image_org)
+        width, height = read_image.size
+
+        classifications = [Classification(class_name=class_name, class_idx=class_names.index(class_name))]
+        sample = Sample(
+            file_path=str(path_image_org.absolute()),
+            width=width,
+            height=height,
+            split=split,
+            classifications=classifications,
+        )
+        samples.append(sample)
+
+    dataset_info = DatasetInfo(
+        dataset_name="ImageClassificationFromDirectoryTree",
+        tasks=[TaskInfo(primitive=Classification, class_names=class_names)],
+    )
+
+    hafnia_dataset = HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+    return hafnia_dataset
+
+
+def export_image_classification_directory_tree(
+    dataset: HafniaDataset,
+    path_output: Path,
+    task_name: Optional[str] = None,
+    clean_folder: bool = False,
+) -> Path:
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=Classification)
+
+    samples = dataset.samples.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.filter(pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+        .alias(task.primitive.column_name())
+    )
+
+    classification_counts = samples[task.primitive.column_name()].list.len()
+    has_no_classification_samples = (classification_counts == 0).sum()
+    if has_no_classification_samples > 0:
+        raise ValueError(f"Some samples do not have a classification for task '{task.name}'.")
+
+    has_multi_classification_samples = (classification_counts > 1).sum()
+    if has_multi_classification_samples > 0:
+        raise ValueError(f"Some samples have multiple classifications for task '{task.name}'.")
+
+    if clean_folder:
+        shutil.rmtree(path_output, ignore_errors=True)
+    path_output.mkdir(parents=True, exist_ok=True)
+
+    description = "Export Hafnia Dataset to directory tree"
+    for sample_dict in track(samples.iter_rows(named=True), total=len(samples), description=description):
+        classifications = sample_dict[task.primitive.column_name()]
+        if len(classifications) != 1:
+            raise ValueError("Each sample should have exactly one classification.")
+        classification = classifications[0]
+        class_name = classification[FieldName.CLASS_NAME].replace("/", "_")  # Avoid issues with subfolders
+        path_class_folder = path_output / class_name
+        path_class_folder.mkdir(parents=True, exist_ok=True)
+
+        path_image_org = Path(sample_dict[ColumnName.FILE_PATH])
+        path_image_new = path_class_folder / path_image_org.name
+        shutil.copy2(path_image_org, path_image_new)
+
+    return path_output
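A round-trip usage sketch of the new conversion module; the directory paths are placeholders and assume a layout of split/class_name/image files (e.g. data/train/cat/001.jpg):

    from pathlib import Path

    from hafnia.dataset.format_conversions.image_classification_from_directory import (
        export_image_classification_directory_tree,
        import_image_classification_directory_tree,
    )

    dataset = import_image_classification_directory_tree(Path("data/train"), split="train", n_samples=500)
    export_image_classification_directory_tree(dataset, path_output=Path("exported/train"), clean_folder=True)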