hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cli/__main__.py +3 -1
  2. cli/config.py +43 -3
  3. cli/keychain.py +88 -0
  4. cli/profile_cmds.py +5 -2
  5. hafnia/__init__.py +1 -1
  6. hafnia/dataset/dataset_helpers.py +9 -2
  7. hafnia/dataset/dataset_names.py +130 -16
  8. hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
  9. hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
  10. hafnia/dataset/dataset_upload_helper.py +83 -22
  11. hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
  12. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  13. hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
  14. hafnia/dataset/hafnia_dataset.py +396 -96
  15. hafnia/dataset/operations/dataset_stats.py +84 -73
  16. hafnia/dataset/operations/dataset_transformations.py +116 -47
  17. hafnia/dataset/operations/table_transformations.py +135 -17
  18. hafnia/dataset/primitives/bbox.py +25 -14
  19. hafnia/dataset/primitives/bitmask.py +22 -15
  20. hafnia/dataset/primitives/classification.py +16 -8
  21. hafnia/dataset/primitives/point.py +7 -3
  22. hafnia/dataset/primitives/polygon.py +15 -10
  23. hafnia/dataset/primitives/primitive.py +1 -1
  24. hafnia/dataset/primitives/segmentation.py +12 -9
  25. hafnia/experiment/hafnia_logger.py +0 -9
  26. hafnia/platform/dataset_recipe.py +7 -2
  27. hafnia/platform/datasets.py +5 -9
  28. hafnia/platform/download.py +24 -90
  29. hafnia/torch_helpers.py +12 -12
  30. hafnia/utils.py +17 -0
  31. hafnia/visualizations/image_visualizations.py +3 -1
  32. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
  33. hafnia-0.4.1.dist-info/RECORD +57 -0
  34. hafnia-0.3.0.dist-info/RECORD +0 -53
  35. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
  36. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py CHANGED
@@ -37,7 +37,9 @@ def configure(cfg: Config) -> None:
 
     platform_url = click.prompt("Hafnia Platform URL", type=str, default=consts.DEFAULT_API_URL)
 
-    cfg_profile = ConfigSchema(api_key=api_key, platform_url=platform_url)
+    use_keychain = click.confirm("Store API key in system keychain?", default=False)
+
+    cfg_profile = ConfigSchema(platform_url=platform_url, api_key=api_key, use_keychain=use_keychain)
     cfg.add_profile(profile_name, cfg_profile, set_active=True)
     cfg.save_config()
     profile_cmds.profile_show(cfg)
cli/config.py CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 from pydantic import BaseModel, field_validator
 
 import cli.consts as consts
+import cli.keychain as keychain
 from hafnia.log import sys_logger, user_logger
 
 PLATFORM_API_MAPPING = {
@@ -19,9 +20,18 @@ PLATFORM_API_MAPPING = {
 }
 
 
+class SecretStr(str):
+    def __repr__(self):
+        return "********"
+
+    def __str__(self):
+        return "********"
+
+
 class ConfigSchema(BaseModel):
     platform_url: str = ""
     api_key: Optional[str] = None
+    use_keychain: bool = False
 
     @field_validator("api_key")
     def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
@@ -35,7 +45,7 @@ class ConfigSchema(BaseModel):
             sys_logger.warning("API key is missing the 'ApiKey ' prefix. Prefix is being added automatically.")
             value = f"ApiKey {value}"
 
-        return value
+        return SecretStr(value)  # Keeps the API key masked in logs and repr
 
 
 class ConfigFileSchema(BaseModel):
@@ -70,13 +80,32 @@ class Config:
 
     @property
     def api_key(self) -> str:
+        # Check keychain first if enabled
+        if self.config.use_keychain:
+            keychain_key = keychain.get_api_key(self.active_profile)
+            if keychain_key is not None:
+                return keychain_key
+
+        # Fall back to config file
         if self.config.api_key is not None:
             return self.config.api_key
+
        raise ValueError(consts.ERROR_API_KEY_NOT_SET)
 
     @api_key.setter
     def api_key(self, value: str) -> None:
-        self.config.api_key = value
+        # Store in keychain if enabled
+        if self.config.use_keychain:
+            if keychain.store_api_key(self.active_profile, value):
+                # Successfully stored in keychain, don't store in config
+                self.config.api_key = None
+            else:
+                # Keychain storage failed, fall back to config file
+                sys_logger.warning("Failed to store in keychain, falling back to config file")
+                self.config.api_key = value
+        else:
+            # Not using keychain, store in config file
+            self.config.api_key = value
 
     @property
     def platform_url(self) -> str:
@@ -152,8 +181,19 @@
             raise ValueError("Failed to parse configuration file")
 
     def save_config(self) -> None:
+        # Create a copy to avoid modifying the original data
+        config_to_save = self.config_data.model_dump()
+
+        # Store API key in keychain if enabled, and don't write to file
+        for profile_name, profile_data in config_to_save["profiles"].items():
+            if profile_data.get("use_keychain", False):
+                api_key = profile_data.get("api_key")
+                if api_key:
+                    keychain.store_api_key(profile_name, api_key)
+                    profile_data["api_key"] = None
+
         with open(self.config_path, "w") as f:
-            json.dump(self.config_data.model_dump(), f, indent=4)
+            json.dump(config_to_save, f, indent=4)
 
     def remove_profile(self, profile_name: str) -> None:
         if profile_name not in self.config_data.profiles:
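
A note on the SecretStr pattern above: because it subclasses str, the real value stays available to anything that uses the string directly; only repr() and str() are masked. A minimal standalone sketch (not package code) of the behavior:

class SecretStr(str):
    def __repr__(self) -> str:
        return "********"

    def __str__(self) -> str:
        return "********"


key = SecretStr("ApiKey abc123")
print(repr(key))                  # ******** - safe to log or echo in a debugger
print(key == "ApiKey abc123")     # True - equality still sees the real value
headers = {"Authorization": key}  # the dict holds the SecretStr itself; HTTP clients
                                  # that accept str subclasses receive the raw key
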
cli/keychain.py ADDED
@@ -0,0 +1,88 @@
+"""Keychain storage for API keys using the system keychain."""
+
+from typing import Optional
+
+from hafnia.log import sys_logger
+
+# Keyring is optional - gracefully degrade if not available
+try:
+    import keyring
+
+    KEYRING_AVAILABLE = True
+except ImportError:
+    KEYRING_AVAILABLE = False
+    sys_logger.debug("keyring library not available, keychain storage disabled")
+
+KEYRING_SERVICE_NAME = "hafnia-cli"
+
+
+def store_api_key(profile_name: str, api_key: str) -> bool:
+    """
+    Store an API key in the system keychain.
+
+    Args:
+        profile_name: The profile name to associate with the key
+        api_key: The API key to store
+
+    Returns:
+        True if successfully stored, False otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        sys_logger.warning("Keyring library not available, cannot store API key in keychain")
+        return False
+
+    try:
+        keyring.set_password(KEYRING_SERVICE_NAME, profile_name, api_key)
+        sys_logger.debug(f"Stored API key for profile '{profile_name}' in keychain")
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to store API key in keychain: {e}")
+        return False
+
+
+def get_api_key(profile_name: str) -> Optional[str]:
+    """
+    Retrieve an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to retrieve the key for
+
+    Returns:
+        The API key if found, None otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        return None
+
+    try:
+        api_key = keyring.get_password(KEYRING_SERVICE_NAME, profile_name)
+        if api_key:
+            sys_logger.debug(f"Retrieved API key for profile '{profile_name}' from keychain")
+        return api_key
+    except Exception as e:
+        sys_logger.warning(f"Failed to retrieve API key from keychain: {e}")
+        return None
+
+
+def delete_api_key(profile_name: str) -> bool:
+    """
+    Delete an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to delete the key for
+
+    Returns:
+        True if successfully deleted or didn't exist, False on error
+    """
+    if not KEYRING_AVAILABLE:
+        return False
+
+    try:
+        keyring.delete_password(KEYRING_SERVICE_NAME, profile_name)
+        sys_logger.debug(f"Deleted API key for profile '{profile_name}' from keychain")
+        return True
+    except keyring.errors.PasswordDeleteError:
+        # Key didn't exist, which is fine
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to delete API key from keychain: {e}")
+        return False
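
The new module is a thin wrapper around the keyring package's set_password / get_password / delete_password API. A usage sketch, assuming the optional keyring dependency is installed and an OS backend (macOS Keychain, Windows Credential Locker, Secret Service, etc.) is available:

from cli import keychain

if keychain.store_api_key("default", "ApiKey abc123"):
    assert keychain.get_api_key("default") == "ApiKey abc123"
    keychain.delete_api_key("default")  # also returns True if the entry was already gone
else:
    # keyring is missing or the backend failed; the CLI falls back to the JSON config file
    print("keychain unavailable")
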
cli/profile_cmds.py CHANGED
@@ -50,10 +50,13 @@ def cmd_profile_use(cfg: Config, profile_name: str) -> None:
 @click.option(
     "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
 )
+@click.option(
+    "--use-keychain", is_flag=True, help="Store API key in system keychain instead of config file", default=False
+)
 @click.pass_obj
-def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
+def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool, use_keychain: bool) -> None:
     """Create a new profile."""
-    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
+    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key, use_keychain=use_keychain)
 
     cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
     profile_show(cfg)
hafnia/__init__.py CHANGED
@@ -3,4 +3,4 @@ from importlib.metadata import version
 __package_name__ = "hafnia"
 __version__ = version(__package_name__)
 
-__dataset_format_version__ = "0.0.2"  # Hafnia dataset format version
+__dataset_format_version__ = "0.2.0"  # Hafnia dataset format version
hafnia/dataset/dataset_helpers.py CHANGED
@@ -38,12 +38,19 @@ def hash_from_bytes(data: bytes) -> str:
 
 def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
     pil_image = Image.fromarray(image)
+    path_image = save_pil_image_with_hash_name(pil_image, path_folder)
+    return path_image
+
+
+def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_skip: bool = True) -> Path:
     buffer = io.BytesIO()
-    pil_image.save(buffer, format="PNG")
+    image.save(buffer, format="PNG")
     hash_value = hash_from_bytes(buffer.getvalue())
     path_image = Path(path_folder) / relative_path_from_hash(hash=hash_value, suffix=".png")
+    if allow_skip and path_image.exists():
+        return path_image
     path_image.parent.mkdir(parents=True, exist_ok=True)
-    pil_image.save(path_image)
+    image.save(path_image)
     return path_image
 
 
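The refactor makes image storage content-addressed: the PNG bytes determine the hash, the hash determines the path, and allow_skip=True avoids rewriting a file that already exists. A usage sketch (the /tmp folder is illustrative):

from pathlib import Path

import numpy as np
from PIL import Image

from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name

image = Image.fromarray(np.zeros((4, 4, 3), dtype=np.uint8))
first = save_pil_image_with_hash_name(image, Path("/tmp/images"))
second = save_pil_image_with_hash_name(image, Path("/tmp/images"))
assert first == second  # identical PNG bytes -> identical hash -> same path, second write skipped
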
hafnia/dataset/dataset_names.py CHANGED
@@ -1,5 +1,8 @@
 from enum import Enum
-from typing import List
+from typing import Dict, List, Optional
+
+import boto3
+from pydantic import BaseModel, field_validator
 
 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
@@ -23,7 +26,7 @@ TAG_IS_SAMPLE = "sample"
 OPS_REMOVE_CLASS = "__REMOVE__"
 
 
-class FieldName:
+class PrimitiveField:
     CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
     CLASS_IDX: str = "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class  # noqa: E501
     OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
@@ -38,39 +41,150 @@ class FieldName:
         Returns a list of expected field names for primitives.
         """
         return [
-            FieldName.CLASS_NAME,
-            FieldName.CLASS_IDX,
-            FieldName.OBJECT_ID,
-            FieldName.CONFIDENCE,
-            FieldName.META,
-            FieldName.TASK_NAME,
+            PrimitiveField.CLASS_NAME,
+            PrimitiveField.CLASS_IDX,
+            PrimitiveField.OBJECT_ID,
+            PrimitiveField.CONFIDENCE,
+            PrimitiveField.META,
+            PrimitiveField.TASK_NAME,
         ]
 
 
-class ColumnName:
-    SAMPLE_INDEX: str = "sample_index"
-    FILE_NAME: str = "file_name"
+class SampleField:
+    FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
+    TAGS: str = "tags"
+
+    CLASSIFICATIONS: str = "classifications"
+    BBOXES: str = "bboxes"
+    BITMASKS: str = "bitmasks"
+    POLYGONS: str = "polygons"
+
+    STORAGE_FORMAT: str = "storage_format"  # E.g. "image", "video", "zip"
+    COLLECTION_INDEX: str = "collection_index"
+    COLLECTION_ID: str = "collection_id"
     REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+    SAMPLE_INDEX: str = "sample_index"
+
     ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
-    TAGS: str = "tags"
     META: str = "meta"
+    DATASET_NAME: str = "dataset_name"
+
+
+class StorageFormat:
+    IMAGE: str = "image"
+    VIDEO: str = "video"
+    ZIP: str = "zip"
 
 
 class SplitName:
-    TRAIN = "train"
-    VAL = "validation"
-    TEST = "test"
-    UNDEFINED = "UNDEFINED"
+    TRAIN: str = "train"
+    VAL: str = "validation"
+    TEST: str = "test"
+    UNDEFINED: str = "UNDEFINED"
 
     @staticmethod
     def valid_splits() -> List[str]:
         return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
 
+    @staticmethod
+    def all_split_names() -> List[str]:
+        return [*SplitName.valid_splits(), SplitName.UNDEFINED]
+
 
 class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
+
+
+class AwsCredentials(BaseModel):
+    access_key: str
+    secret_key: str
+    session_token: str
+    region: Optional[str]
+
+    def aws_credentials(self) -> Dict[str, str]:
+        """
+        Returns the AWS credentials as a dictionary.
+        """
+        environment_vars = {
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_SESSION_TOKEN": self.session_token,
+        }
+        if self.region:
+            environment_vars["AWS_REGION"] = self.region
+
+        return environment_vars
+
+    @staticmethod
+    def from_session(session: boto3.Session) -> "AwsCredentials":
+        """
+        Creates AwsCredentials from a boto3 session.
+        """
+        frozen_credentials = session.get_credentials().get_frozen_credentials()
+        return AwsCredentials(
+            access_key=frozen_credentials.access_key,
+            secret_key=frozen_credentials.secret_key,
+            session_token=frozen_credentials.token,
+            region=session.region_name,
+        )
+
+
+ARN_PREFIX = "arn:aws:s3:::"
+
+
+class ResourceCredentials(AwsCredentials):
+    s3_arn: str
+
+    @staticmethod
+    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+        """
+        The endpoint returns a payload with a key called 's3_path', but it
+        is actually an ARN path (starts with 'arn:aws:s3:::'). This method renames it to 's3_arn' for consistency.
+        """
+        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+            payload["s3_arn"] = payload.pop("s3_path")
+
+        if "region" not in payload:
+            payload["region"] = "eu-west-1"
+        return ResourceCredentials(**payload)
+
+    @field_validator("s3_arn")
+    @classmethod
+    def validate_s3_arn(cls, value: str) -> str:
+        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+        if not value.startswith("arn:aws:s3:::"):
+            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+        return value
+
+    def s3_path(self) -> str:
+        """
+        Extracts the S3 path from the ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+        """
+        return self.s3_arn[len(ARN_PREFIX) :]
+
+    def s3_uri(self) -> str:
+        """
+        Converts the S3 ARN to a URI format.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+        """
+        return f"s3://{self.s3_path()}"
+
+    def bucket_name(self) -> str:
+        """
+        Extracts the bucket name from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+        """
+        return self.s3_path().split("/")[0]
+
+    def object_key(self) -> str:
+        """
+        Extracts the object key from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+        """
+        return "/".join(self.s3_path().split("/")[1:])
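
The new ResourceCredentials helpers are simple slices of the ARN after the arn:aws:s3::: prefix. A sketch of the expected round-trip, with placeholder credential values:

from hafnia.dataset.dataset_names import ResourceCredentials

creds = ResourceCredentials.fix_naming(
    {
        "access_key": "AKIA_PLACEHOLDER",
        "secret_key": "SECRET_PLACEHOLDER",
        "session_token": "TOKEN_PLACEHOLDER",
        "s3_path": "arn:aws:s3:::my-bucket/datasets/coco",  # renamed to s3_arn by fix_naming()
    }
)
assert creds.s3_uri() == "s3://my-bucket/datasets/coco"
assert creds.bucket_name() == "my-bucket"
assert creds.object_key() == "datasets/coco"
assert creds.region == "eu-west-1"  # default injected by fix_naming()
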
hafnia/dataset/dataset_recipe/dataset_recipe.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import (
     field_serializer,
@@ -12,7 +12,11 @@ from pydantic import (
 
 from hafnia import utils
 from hafnia.dataset.dataset_recipe import recipe_transforms
-from hafnia.dataset.dataset_recipe.recipe_types import RecipeCreation, RecipeTransform, Serializable
+from hafnia.dataset.dataset_recipe.recipe_types import (
+    RecipeCreation,
+    RecipeTransform,
+    Serializable,
+)
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.dataset.primitives.primitive import Primitive
 
@@ -41,6 +45,17 @@ class DatasetRecipe(Serializable):
         creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
         return DatasetRecipe(creation=creation)
 
+    @staticmethod
+    def from_name_public_dataset(
+        name: str, force_redownload: bool = False, n_samples: Optional[int] = None
+    ) -> DatasetRecipe:
+        creation = FromNamePublicDataset(
+            name=name,
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+        return DatasetRecipe(creation=creation)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
         creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
@@ -222,7 +237,7 @@
         """Serialize the dataset recipe to a dictionary."""
         return self.model_dump(mode="json")
 
-    def as_platform_recipe(self, recipe_name: Optional[str]) -> Dict:
+    def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
         """Uploads dataset recipe to the hafnia platform."""
         from cli.config import Config
         from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
@@ -235,6 +250,7 @@
             endpoint=endpoint_dataset,
             api_key=cfg.api_key,
             name=recipe_name,
+            overwrite=overwrite,
         )
 
         return recipe_dict
@@ -246,10 +262,17 @@
         return recipe
 
     def select_samples(
-        recipe: DatasetRecipe, n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+        recipe: DatasetRecipe,
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> DatasetRecipe:
         operation = recipe_transforms.SelectSamples(
-            n_samples=n_samples, shuffle=shuffle, seed=seed, with_replacement=with_replacement
+            n_samples=n_samples,
+            shuffle=shuffle,
+            seed=seed,
+            with_replacement=with_replacement,
         )
         recipe.append_operation(operation)
         return recipe
@@ -273,7 +296,7 @@
 
     def class_mapper(
         recipe: DatasetRecipe,
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -400,6 +423,22 @@ class FromName(RecipeCreation):
         return [self.name]
 
 
+class FromNamePublicDataset(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    n_samples: Optional[int] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name_public_dataset
+
+    def as_short_name(self) -> str:
+        return f"Torchvision('{self.name}')"
+
+    def get_dataset_names(self) -> List[str]:
+        return []
+
+
 class FromMerge(RecipeCreation):
     recipe0: DatasetRecipe
     recipe1: DatasetRecipe
@@ -414,7 +453,10 @@
 
     def get_dataset_names(self) -> List[str]:
         """Get the dataset names from the merged recipes."""
-        names = [*self.recipe0.creation.get_dataset_names(), *self.recipe1.creation.get_dataset_names()]
+        names = [
+            *self.recipe0.creation.get_dataset_names(),
+            *self.recipe1.creation.get_dataset_names(),
+        ]
         return names
 
 
@@ -439,33 +481,3 @@ class FromMerger(RecipeCreation):
         for recipe in self.recipes:
             names.extend(recipe.creation.get_dataset_names())
         return names
-
-
-def extract_dataset_names_from_json_dict(data: dict) -> list[str]:
-    """
-    Extract dataset names recursively from a JSON dictionary added with 'from_name'.
-
-    Even if the same functionality is achieved with `DatasetRecipe.get_dataset_names()`,
-    we want to keep this function in 'dipdatalib' to extract dataset names from json dictionaries
-    directly.
-    """
-    creation_field = data.get("creation")
-    if creation_field is None:
-        return []
-    if creation_field.get("__type__") == "FromName":
-        return [creation_field["name"]]
-    elif creation_field.get("__type__") == "FromMerge":
-        recipe_names = ["recipe0", "recipe1"]
-        dataset_name = []
-        for recipe_name in recipe_names:
-            recipe = creation_field.get(recipe_name)
-            if recipe is None:
-                continue
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    elif creation_field.get("__type__") == "FromMerger":
-        dataset_name = []
-        for recipe in creation_field.get("recipes", []):
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    return []
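
Taken together, the new creation entry point and the transform methods chain into a single declarative recipe. A sketch, under the assumption that "mnist" is a valid public dataset name and the class names are illustrative:

from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

recipe = DatasetRecipe.from_name_public_dataset("mnist", n_samples=1_000)
recipe = recipe.select_samples(n_samples=100, shuffle=True, seed=42)
recipe = recipe.class_mapper(class_mapping=[("0", "zero"), ("1", "one")])

# Registering on the platform now supports overwriting an existing recipe:
# recipe.as_platform_recipe("my-recipe", overwrite=True)
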
hafnia/dataset/dataset_recipe/recipe_transforms.py CHANGED
@@ -1,4 +1,6 @@
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import field_validator
 
 from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
 from hafnia.dataset.hafnia_dataset import HafniaDataset
@@ -52,11 +54,25 @@ class DefineSampleSetBySize(RecipeTransform):
 
 
 class ClassMapper(RecipeTransform):
-    class_mapping: Dict[str, str]
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]]
     method: str = "strict"
    primitive: Optional[Type[Primitive]] = None
     task_name: Optional[str] = None
 
+    @field_validator("class_mapping", mode="after")
+    @classmethod
+    def serialize_class_mapping(cls, value: Union[Dict[str, str], List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
+        # Converts the dictionary class mapping to a list of tuples,
+        # e.g. {"old_class": "new_class"} --> [("old_class", "new_class")].
+        # Storing class mappings as a dictionary does not preserve the order of JSON fields
+        # when stored in a database as a jsonb field (Postgres).
+        # Preserving the order of class mapping fields is important, as it defines the indices of the classes.
+        # Converting the dictionary to a list of tuples keeps that order stable.
+        if isinstance(value, dict):
+            value = list(value.items())
+        return value
+
     @staticmethod
     def get_function() -> Callable[..., "HafniaDataset"]:
         return HafniaDataset.class_mapper
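
The practical effect of the validator is easy to check: a dict mapping is normalized to an ordered list of tuples, so the class indices implied by insertion order survive a jsonb round-trip. A short sketch:

from hafnia.dataset.dataset_recipe.recipe_transforms import ClassMapper

mapper = ClassMapper(class_mapping={"car": "vehicle", "truck": "vehicle"})
assert mapper.class_mapping == [("car", "vehicle"), ("truck", "vehicle")]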