hafnia 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +3 -1
- cli/config.py +43 -3
- cli/keychain.py +88 -0
- cli/profile_cmds.py +5 -2
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_helpers.py +9 -2
- hafnia/dataset/dataset_names.py +130 -16
- hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
- hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
- hafnia/dataset/dataset_upload_helper.py +83 -22
- hafnia/dataset/format_conversions/format_image_classification_folder.py +110 -0
- hafnia/dataset/format_conversions/format_yolo.py +164 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +287 -0
- hafnia/dataset/hafnia_dataset.py +396 -96
- hafnia/dataset/operations/dataset_stats.py +84 -73
- hafnia/dataset/operations/dataset_transformations.py +116 -47
- hafnia/dataset/operations/table_transformations.py +135 -17
- hafnia/dataset/primitives/bbox.py +25 -14
- hafnia/dataset/primitives/bitmask.py +22 -15
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +15 -10
- hafnia/dataset/primitives/primitive.py +1 -1
- hafnia/dataset/primitives/segmentation.py +12 -9
- hafnia/experiment/hafnia_logger.py +0 -9
- hafnia/platform/dataset_recipe.py +7 -2
- hafnia/platform/datasets.py +5 -9
- hafnia/platform/download.py +24 -90
- hafnia/torch_helpers.py +12 -12
- hafnia/utils.py +17 -0
- hafnia/visualizations/image_visualizations.py +3 -1
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/METADATA +11 -9
- hafnia-0.4.1.dist-info/RECORD +57 -0
- hafnia-0.3.0.dist-info/RECORD +0 -53
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/WHEEL +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.1.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py
CHANGED

@@ -37,7 +37,9 @@ def configure(cfg: Config) -> None:
 
     platform_url = click.prompt("Hafnia Platform URL", type=str, default=consts.DEFAULT_API_URL)
 
-
+    use_keychain = click.confirm("Store API key in system keychain?", default=False)
+
+    cfg_profile = ConfigSchema(platform_url=platform_url, api_key=api_key, use_keychain=use_keychain)
     cfg.add_profile(profile_name, cfg_profile, set_active=True)
     cfg.save_config()
     profile_cmds.profile_show(cfg)
cli/config.py
CHANGED

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 from pydantic import BaseModel, field_validator
 
 import cli.consts as consts
+import cli.keychain as keychain
 from hafnia.log import sys_logger, user_logger
 
 PLATFORM_API_MAPPING = {
@@ -19,9 +20,18 @@ PLATFORM_API_MAPPING = {
 }
 
 
+class SecretStr(str):
+    def __repr__(self):
+        return "********"
+
+    def __str__(self):
+        return "********"
+
+
 class ConfigSchema(BaseModel):
     platform_url: str = ""
     api_key: Optional[str] = None
+    use_keychain: bool = False
 
     @field_validator("api_key")
     def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
@@ -35,7 +45,7 @@ class ConfigSchema(BaseModel):
             sys_logger.warning("API key is missing the 'ApiKey ' prefix. Prefix is being added automatically.")
             value = f"ApiKey {value}"
 
-        return value
+        return SecretStr(value)  # Keeps the API key masked in logs and repr
 
 
 class ConfigFileSchema(BaseModel):
@@ -70,13 +80,32 @@ class Config:
 
     @property
     def api_key(self) -> str:
+        # Check keychain first if enabled
+        if self.config.use_keychain:
+            keychain_key = keychain.get_api_key(self.active_profile)
+            if keychain_key is not None:
+                return keychain_key
+
+        # Fall back to config file
        if self.config.api_key is not None:
            return self.config.api_key
+
        raise ValueError(consts.ERROR_API_KEY_NOT_SET)
 
     @api_key.setter
     def api_key(self, value: str) -> None:
-
+        # Store in keychain if enabled
+        if self.config.use_keychain:
+            if keychain.store_api_key(self.active_profile, value):
+                # Successfully stored in keychain, don't store in config
+                self.config.api_key = None
+            else:
+                # Keychain storage failed, fall back to config file
+                sys_logger.warning("Failed to store in keychain, falling back to config file")
+                self.config.api_key = value
+        else:
+            # Not using keychain, store in config file
+            self.config.api_key = value
 
     @property
     def platform_url(self) -> str:
@@ -152,8 +181,19 @@ class Config:
         raise ValueError("Failed to parse configuration file")
 
     def save_config(self) -> None:
+        # Create a copy to avoid modifying the original data
+        config_to_save = self.config_data.model_dump()
+
+        # Store API key in keychain if enabled, and don't write to file
+        for profile_name, profile_data in config_to_save["profiles"].items():
+            if profile_data.get("use_keychain", False):
+                api_key = profile_data.get("api_key")
+                if api_key:
+                    keychain.store_api_key(profile_name, api_key)
+                    profile_data["api_key"] = None
+
         with open(self.config_path, "w") as f:
-            json.dump(
+            json.dump(config_to_save, f, indent=4)
 
     def remove_profile(self, profile_name: str) -> None:
         if profile_name not in self.config_data.profiles:
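Note: the `SecretStr` introduced above only changes how the key renders; the stored value is untouched. A minimal standalone sketch of that masking behaviour (illustrative only, not the packaged module):

```python
class SecretStr(str):
    """A str whose repr/str are masked so API keys do not leak into logs or tracebacks."""

    def __repr__(self) -> str:
        return "********"

    def __str__(self) -> str:
        return "********"


key = SecretStr("ApiKey abc123")
print(repr(key))               # ********
print(key)                     # ******** (print() routes through __str__)
print(key == "ApiKey abc123")  # True - comparisons still see the real value
```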
cli/keychain.py
ADDED

@@ -0,0 +1,88 @@
+"""Keychain storage for API keys using the system keychain."""
+
+from typing import Optional
+
+from hafnia.log import sys_logger
+
+# Keyring is optional - gracefully degrade if not available
+try:
+    import keyring
+
+    KEYRING_AVAILABLE = True
+except ImportError:
+    KEYRING_AVAILABLE = False
+    sys_logger.debug("keyring library not available, keychain storage disabled")
+
+KEYRING_SERVICE_NAME = "hafnia-cli"
+
+
+def store_api_key(profile_name: str, api_key: str) -> bool:
+    """
+    Store an API key in the system keychain.
+
+    Args:
+        profile_name: The profile name to associate with the key
+        api_key: The API key to store
+
+    Returns:
+        True if successfully stored, False otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        sys_logger.warning("Keyring library not available, cannot store API key in keychain")
+        return False
+
+    try:
+        keyring.set_password(KEYRING_SERVICE_NAME, profile_name, api_key)
+        sys_logger.debug(f"Stored API key for profile '{profile_name}' in keychain")
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to store API key in keychain: {e}")
+        return False
+
+
+def get_api_key(profile_name: str) -> Optional[str]:
+    """
+    Retrieve an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to retrieve the key for
+
+    Returns:
+        The API key if found, None otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        return None
+
+    try:
+        api_key = keyring.get_password(KEYRING_SERVICE_NAME, profile_name)
+        if api_key:
+            sys_logger.debug(f"Retrieved API key for profile '{profile_name}' from keychain")
+        return api_key
+    except Exception as e:
+        sys_logger.warning(f"Failed to retrieve API key from keychain: {e}")
+        return None
+
+
+def delete_api_key(profile_name: str) -> bool:
+    """
+    Delete an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to delete the key for
+
+    Returns:
+        True if successfully deleted or didn't exist, False on error
+    """
+    if not KEYRING_AVAILABLE:
+        return False
+
+    try:
+        keyring.delete_password(KEYRING_SERVICE_NAME, profile_name)
+        sys_logger.debug(f"Deleted API key for profile '{profile_name}' from keychain")
+        return True
+    except keyring.errors.PasswordDeleteError:
+        # Key didn't exist, which is fine
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to delete API key from keychain: {e}")
+        return False
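A rough usage sketch of the new module, assuming the optional `keyring` dependency is installed and a system backend is available; function names and return values follow the file above:

```python
from cli import keychain

# store_api_key returns False when no keyring backend is usable, which is how
# cli/config.py decides to fall back to writing the key into the config file.
if keychain.store_api_key("default", "ApiKey abc123"):
    assert keychain.get_api_key("default") == "ApiKey abc123"
    keychain.delete_api_key("default")  # also returns True if the key was already gone
else:
    print("keychain unavailable - the key would stay in the config file instead")
```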
cli/profile_cmds.py
CHANGED

@@ -50,10 +50,13 @@ def cmd_profile_use(cfg: Config, profile_name: str) -> None:
 @click.option(
     "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
 )
+@click.option(
+    "--use-keychain", is_flag=True, help="Store API key in system keychain instead of config file", default=False
+)
 @click.pass_obj
-def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
+def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool, use_keychain: bool) -> None:
     """Create a new profile."""
-    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
+    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key, use_keychain=use_keychain)
 
     cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
     profile_show(cfg)
hafnia/__init__.py
CHANGED

hafnia/dataset/dataset_helpers.py
CHANGED

@@ -38,12 +38,19 @@ def hash_from_bytes(data: bytes) -> str:
 
 def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
     pil_image = Image.fromarray(image)
+    path_image = save_pil_image_with_hash_name(pil_image, path_folder)
+    return path_image
+
+
+def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_skip: bool = True) -> Path:
     buffer = io.BytesIO()
-
+    image.save(buffer, format="PNG")
     hash_value = hash_from_bytes(buffer.getvalue())
     path_image = Path(path_folder) / relative_path_from_hash(hash=hash_value, suffix=".png")
+    if allow_skip and path_image.exists():
+        return path_image
     path_image.parent.mkdir(parents=True, exist_ok=True)
-
+    image.save(path_image)
     return path_image
 
 
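A usage sketch of the new PIL-level helper (the import path is assumed from the file list above; Pillow and a writable folder are required). The file name is derived from a hash of the encoded PNG bytes, so saving identical pixel data twice resolves to the same path, and the second call returns early when `allow_skip=True`:

```python
from pathlib import Path

from PIL import Image

from hafnia.dataset.dataset_helpers import save_pil_image_with_hash_name  # assumed module path

image = Image.new("RGB", (32, 32), color=(255, 0, 0))
first = save_pil_image_with_hash_name(image, Path("/tmp/hafnia-images"))
second = save_pil_image_with_hash_name(image, Path("/tmp/hafnia-images"))
assert first == second  # content-addressed: identical bytes map to a single file on disk
```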
hafnia/dataset/dataset_names.py
CHANGED

@@ -1,5 +1,8 @@
 from enum import Enum
-from typing import List
+from typing import Dict, List, Optional
+
+import boto3
+from pydantic import BaseModel, field_validator
 
 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
@@ -23,7 +26,7 @@ TAG_IS_SAMPLE = "sample"
 OPS_REMOVE_CLASS = "__REMOVE__"
 
 
-class
+class PrimitiveField:
     CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
     CLASS_IDX: str = "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class  # noqa: E501
     OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
@@ -38,39 +41,150 @@ class FieldName:
        Returns a list of expected field names for primitives.
        """
        return [
-
-
-
-
-
-
+            PrimitiveField.CLASS_NAME,
+            PrimitiveField.CLASS_IDX,
+            PrimitiveField.OBJECT_ID,
+            PrimitiveField.CONFIDENCE,
+            PrimitiveField.META,
+            PrimitiveField.TASK_NAME,
        ]
 
 
-class
-
-    FILE_NAME: str = "file_name"
+class SampleField:
+    FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
+    TAGS: str = "tags"
+
+    CLASSIFICATIONS: str = "classifications"
+    BBOXES: str = "bboxes"
+    BITMASKS: str = "bitmasks"
+    POLYGONS: str = "polygons"
+
+    STORAGE_FORMAT: str = "storage_format"  # E.g. "image", "video", "zip"
+    COLLECTION_INDEX: str = "collection_index"
+    COLLECTION_ID: str = "collection_id"
     REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+    SAMPLE_INDEX: str = "sample_index"
+
     ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
-    TAGS: str = "tags"
     META: str = "meta"
+    DATASET_NAME: str = "dataset_name"
+
+
+class StorageFormat:
+    IMAGE: str = "image"
+    VIDEO: str = "video"
+    ZIP: str = "zip"
 
 
 class SplitName:
-    TRAIN = "train"
-    VAL = "validation"
-    TEST = "test"
-    UNDEFINED = "UNDEFINED"
+    TRAIN: str = "train"
+    VAL: str = "validation"
+    TEST: str = "test"
+    UNDEFINED: str = "UNDEFINED"
 
     @staticmethod
     def valid_splits() -> List[str]:
         return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
 
+    @staticmethod
+    def all_split_names() -> List[str]:
+        return [*SplitName.valid_splits(), SplitName.UNDEFINED]
+
 
 class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
+
+
+class AwsCredentials(BaseModel):
+    access_key: str
+    secret_key: str
+    session_token: str
+    region: Optional[str]
+
+    def aws_credentials(self) -> Dict[str, str]:
+        """
+        Returns the AWS credentials as a dictionary.
+        """
+        environment_vars = {
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_SESSION_TOKEN": self.session_token,
+        }
+        if self.region:
+            environment_vars["AWS_REGION"] = self.region
+
+        return environment_vars
+
+    @staticmethod
+    def from_session(session: boto3.Session) -> "AwsCredentials":
+        """
+        Creates AwsCredentials from a Boto3 session.
+        """
+        frozen_credentials = session.get_credentials().get_frozen_credentials()
+        return AwsCredentials(
+            access_key=frozen_credentials.access_key,
+            secret_key=frozen_credentials.secret_key,
+            session_token=frozen_credentials.token,
+            region=session.region_name,
+        )
+
+
+ARN_PREFIX = "arn:aws:s3:::"
+
+
+class ResourceCredentials(AwsCredentials):
+    s3_arn: str
+
+    @staticmethod
+    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+        """
+        The endpoint returns a payload with a key called 's3_path', but it
+        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+        """
+        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+            payload["s3_arn"] = payload.pop("s3_path")
+
+        if "region" not in payload:
+            payload["region"] = "eu-west-1"
+        return ResourceCredentials(**payload)
+
+    @field_validator("s3_arn")
+    @classmethod
+    def validate_s3_arn(cls, value: str) -> str:
+        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+        if not value.startswith("arn:aws:s3:::"):
+            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+        return value
+
+    def s3_path(self) -> str:
+        """
+        Extracts the S3 path from the ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+        """
+        return self.s3_arn[len(ARN_PREFIX) :]
+
+    def s3_uri(self) -> str:
+        """
+        Converts the S3 ARN to a URI format.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+        """
+        return f"s3://{self.s3_path()}"
+
+    def bucket_name(self) -> str:
+        """
+        Extracts the bucket name from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+        """
+        return self.s3_path().split("/")[0]
+
+    def object_key(self) -> str:
+        """
+        Extracts the object key from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+        """
+        return "/".join(self.s3_path().split("/")[1:])
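A sketch of how the new credential models resolve the platform payload (all values are placeholders): `fix_naming` renames the `s3_path` key to `s3_arn` and defaults the region, after which the ARN helpers split out the URI, bucket, and object key:

```python
from hafnia.dataset.dataset_names import ResourceCredentials

payload = {
    "access_key": "AKIA...",  # placeholder credentials
    "secret_key": "...",
    "session_token": "...",
    "s3_path": "arn:aws:s3:::my-bucket/datasets/my-dataset",
}
creds = ResourceCredentials.fix_naming(payload)
print(creds.s3_uri())        # s3://my-bucket/datasets/my-dataset
print(creds.bucket_name())   # my-bucket
print(creds.object_key())    # datasets/my-dataset
print(creds.aws_credentials()["AWS_REGION"])  # eu-west-1, the default added by fix_naming
```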
hafnia/dataset/dataset_recipe/dataset_recipe.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import (
     field_serializer,
@@ -12,7 +12,11 @@ from pydantic import (
 
 from hafnia import utils
 from hafnia.dataset.dataset_recipe import recipe_transforms
-from hafnia.dataset.dataset_recipe.recipe_types import
+from hafnia.dataset.dataset_recipe.recipe_types import (
+    RecipeCreation,
+    RecipeTransform,
+    Serializable,
+)
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.dataset.primitives.primitive import Primitive
 
@@ -41,6 +45,17 @@ class DatasetRecipe(Serializable):
         creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
         return DatasetRecipe(creation=creation)
 
+    @staticmethod
+    def from_name_public_dataset(
+        name: str, force_redownload: bool = False, n_samples: Optional[int] = None
+    ) -> DatasetRecipe:
+        creation = FromNamePublicDataset(
+            name=name,
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+        return DatasetRecipe(creation=creation)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
         creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
@@ -222,7 +237,7 @@ class DatasetRecipe(Serializable):
         """Serialize the dataset recipe to a dictionary."""
         return self.model_dump(mode="json")
 
-    def as_platform_recipe(self, recipe_name: Optional[str]) -> Dict:
+    def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
         """Uploads dataset recipe to the hafnia platform."""
         from cli.config import Config
         from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
@@ -235,6 +250,7 @@ class DatasetRecipe(Serializable):
             endpoint=endpoint_dataset,
             api_key=cfg.api_key,
             name=recipe_name,
+            overwrite=overwrite,
         )
 
         return recipe_dict
@@ -246,10 +262,17 @@ class DatasetRecipe(Serializable):
         return recipe
 
     def select_samples(
-        recipe: DatasetRecipe,
+        recipe: DatasetRecipe,
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> DatasetRecipe:
         operation = recipe_transforms.SelectSamples(
-            n_samples=n_samples,
+            n_samples=n_samples,
+            shuffle=shuffle,
+            seed=seed,
+            with_replacement=with_replacement,
         )
         recipe.append_operation(operation)
         return recipe
@@ -273,7 +296,7 @@ class DatasetRecipe(Serializable):
 
     def class_mapper(
         recipe: DatasetRecipe,
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -400,6 +423,22 @@ class FromName(RecipeCreation):
         return [self.name]
 
 
+class FromNamePublicDataset(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    n_samples: Optional[int] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name_public_dataset
+
+    def as_short_name(self) -> str:
+        return f"Torchvision('{self.name}')"
+
+    def get_dataset_names(self) -> List[str]:
+        return []
+
+
 class FromMerge(RecipeCreation):
     recipe0: DatasetRecipe
     recipe1: DatasetRecipe
@@ -414,7 +453,10 @@ class FromMerge(RecipeCreation):
 
     def get_dataset_names(self) -> List[str]:
         """Get the dataset names from the merged recipes."""
-        names = [
+        names = [
+            *self.recipe0.creation.get_dataset_names(),
+            *self.recipe1.creation.get_dataset_names(),
+        ]
         return names
 
 
@@ -439,33 +481,3 @@ class FromMerger(RecipeCreation):
         for recipe in self.recipes:
             names.extend(recipe.creation.get_dataset_names())
         return names
-
-
-def extract_dataset_names_from_json_dict(data: dict) -> list[str]:
-    """
-    Extract dataset names recursively from a JSON dictionary added with 'from_name'.
-
-    Even if the same functionality is achieved with `DatasetRecipe.get_dataset_names()`,
-    we want to keep this function in 'dipdatalib' to extract dataset names from json dictionaries
-    directly.
-    """
-    creation_field = data.get("creation")
-    if creation_field is None:
-        return []
-    if creation_field.get("__type__") == "FromName":
-        return [creation_field["name"]]
-    elif creation_field.get("__type__") == "FromMerge":
-        recipe_names = ["recipe0", "recipe1"]
-        dataset_name = []
-        for recipe_name in recipe_names:
-            recipe = creation_field.get(recipe_name)
-            if recipe is None:
-                continue
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    elif creation_field.get("__type__") == "FromMerger":
-        dataset_name = []
-        for recipe in creation_field.get("recipes", []):
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    return []
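Putting the new pieces together, a hedged sketch of recipe construction (the dataset name is a made-up placeholder, and the chaining relies on the signatures above, where each transform appends an operation and returns the recipe):

```python
from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

# New creation step for public (torchvision-style) datasets.
recipe = DatasetRecipe.from_name_public_dataset("cifar10", n_samples=2000)  # placeholder name

# select_samples now exposes shuffle/seed/with_replacement explicitly.
recipe = recipe.select_samples(n_samples=500, shuffle=True, seed=42, with_replacement=False)

# class_mapper accepts an ordered list of (old, new) pairs as well as a dict.
recipe = recipe.class_mapper(class_mapping=[("automobile", "car"), ("truck", "car")], method="strict")
```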
hafnia/dataset/dataset_recipe/recipe_transforms.py
CHANGED

@@ -1,4 +1,6 @@
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import field_validator
 
 from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
 from hafnia.dataset.hafnia_dataset import HafniaDataset
@@ -52,11 +54,25 @@ class DefineSampleSetBySize(RecipeTransform):
 
 
 class ClassMapper(RecipeTransform):
-    class_mapping: Dict[str, str]
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]]
     method: str = "strict"
     primitive: Optional[Type[Primitive]] = None
    task_name: Optional[str] = None
 
+    @field_validator("class_mapping", mode="after")
+    @classmethod
+    def serialize_class_mapping(cls, value: Union[Dict[str, str], List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
+        # Converts the dictionary class mapping to a list of tuples
+        # e.g. {"old_class": "new_class", } --> [("old_class", "new_class")]
+        # The reason is that storing class mappings as a dictionary does not preserve order of json fields
+        # when stored in a database as a jsonb field (postgres).
+        # Preserving order of class mapping fields is important as it defines the indices of the classes.
+        # So to ensure that class indices are maintained, we preserve order of json fields, by converting the
+        # dictionary to a list of tuples.
+        if isinstance(value, dict):
+            value = list(value.items())
+        return value
+
     @staticmethod
     def get_function() -> Callable[..., "HafniaDataset"]:
         return HafniaDataset.class_mapper
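A small illustration of what `serialize_class_mapping` does with a dict input (module path assumed from the file list; the printed value is the expected normalized form):

```python
from hafnia.dataset.dataset_recipe.recipe_transforms import ClassMapper  # assumed module path

mapper = ClassMapper(class_mapping={"person": "pedestrian", "car": "vehicle"})
print(mapper.class_mapping)
# [('person', 'pedestrian'), ('car', 'vehicle')] - order preserved, unlike a dict stored as JSONB
```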