hafnia-0.3.0-py3-none-any.whl → hafnia-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +3 -1
- cli/config.py +43 -3
- cli/keychain.py +88 -0
- cli/profile_cmds.py +5 -2
- hafnia/__init__.py +1 -1
- hafnia/dataset/dataset_helpers.py +9 -2
- hafnia/dataset/dataset_names.py +2 -1
- hafnia/dataset/dataset_recipe/dataset_recipe.py +49 -37
- hafnia/dataset/dataset_recipe/recipe_transforms.py +18 -2
- hafnia/dataset/dataset_upload_helper.py +60 -4
- hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
- hafnia/dataset/hafnia_dataset.py +176 -50
- hafnia/dataset/operations/dataset_stats.py +2 -3
- hafnia/dataset/operations/dataset_transformations.py +19 -15
- hafnia/dataset/operations/table_transformations.py +4 -3
- hafnia/dataset/primitives/bbox.py +25 -12
- hafnia/dataset/primitives/bitmask.py +26 -14
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +16 -9
- hafnia/dataset/primitives/segmentation.py +10 -7
- hafnia/experiment/hafnia_logger.py +0 -9
- hafnia/platform/dataset_recipe.py +7 -2
- hafnia/platform/datasets.py +3 -3
- hafnia/platform/download.py +23 -18
- hafnia/utils.py +17 -0
- hafnia/visualizations/image_visualizations.py +1 -1
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/METADATA +8 -6
- hafnia-0.4.0.dist-info/RECORD +56 -0
- hafnia-0.3.0.dist-info/RECORD +0 -53
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.3.0.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py
CHANGED
@@ -37,7 +37,9 @@ def configure(cfg: Config) -> None:

     platform_url = click.prompt("Hafnia Platform URL", type=str, default=consts.DEFAULT_API_URL)

-    cfg_profile = ConfigSchema(platform_url=platform_url, api_key=api_key)
+    use_keychain = click.confirm("Store API key in system keychain?", default=False)
+
+    cfg_profile = ConfigSchema(platform_url=platform_url, api_key=api_key, use_keychain=use_keychain)
     cfg.add_profile(profile_name, cfg_profile, set_active=True)
     cfg.save_config()
     profile_cmds.profile_show(cfg)
cli/config.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 from pydantic import BaseModel, field_validator

 import cli.consts as consts
+import cli.keychain as keychain
 from hafnia.log import sys_logger, user_logger

 PLATFORM_API_MAPPING = {
@@ -19,9 +20,18 @@ PLATFORM_API_MAPPING = {
 }


+class SecretStr(str):
+    def __repr__(self):
+        return "********"
+
+    def __str__(self):
+        return "********"
+
+
 class ConfigSchema(BaseModel):
     platform_url: str = ""
     api_key: Optional[str] = None
+    use_keychain: bool = False

     @field_validator("api_key")
     def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
@@ -35,7 +45,7 @@ class ConfigSchema(BaseModel):
             sys_logger.warning("API key is missing the 'ApiKey ' prefix. Prefix is being added automatically.")
             value = f"ApiKey {value}"

-        return value
+        return SecretStr(value)  # Keeps the API key masked in logs and repr


 class ConfigFileSchema(BaseModel):
@@ -70,13 +80,32 @@ class Config:

     @property
     def api_key(self) -> str:
+        # Check keychain first if enabled
+        if self.config.use_keychain:
+            keychain_key = keychain.get_api_key(self.active_profile)
+            if keychain_key is not None:
+                return keychain_key
+
+        # Fall back to config file
         if self.config.api_key is not None:
             return self.config.api_key
+
         raise ValueError(consts.ERROR_API_KEY_NOT_SET)

     @api_key.setter
     def api_key(self, value: str) -> None:
-        self.config.api_key = value
+        # Store in keychain if enabled
+        if self.config.use_keychain:
+            if keychain.store_api_key(self.active_profile, value):
+                # Successfully stored in keychain, don't store in config
+                self.config.api_key = None
+            else:
+                # Keychain storage failed, fall back to config file
+                sys_logger.warning("Failed to store in keychain, falling back to config file")
+                self.config.api_key = value
+        else:
+            # Not using keychain, store in config file
+            self.config.api_key = value

     @property
     def platform_url(self) -> str:
@@ -152,8 +181,19 @@ class Config:
         raise ValueError("Failed to parse configuration file")

     def save_config(self) -> None:
+        # Create a copy to avoid modifying the original data
+        config_to_save = self.config_data.model_dump()
+
+        # Store API key in keychain if enabled, and don't write to file
+        for profile_name, profile_data in config_to_save["profiles"].items():
+            if profile_data.get("use_keychain", False):
+                api_key = profile_data.get("api_key")
+                if api_key:
+                    keychain.store_api_key(profile_name, api_key)
+                    profile_data["api_key"] = None
+
         with open(self.config_path, "w") as f:
-            json.dump(
+            json.dump(config_to_save, f, indent=4)

     def remove_profile(self, profile_name: str) -> None:
         if profile_name not in self.config_data.profiles:
cli/keychain.py
ADDED
@@ -0,0 +1,88 @@
+"""Keychain storage for API keys using the system keychain."""
+
+from typing import Optional
+
+from hafnia.log import sys_logger
+
+# Keyring is optional - gracefully degrade if not available
+try:
+    import keyring
+
+    KEYRING_AVAILABLE = True
+except ImportError:
+    KEYRING_AVAILABLE = False
+    sys_logger.debug("keyring library not available, keychain storage disabled")
+
+KEYRING_SERVICE_NAME = "hafnia-cli"
+
+
+def store_api_key(profile_name: str, api_key: str) -> bool:
+    """
+    Store an API key in the system keychain.
+
+    Args:
+        profile_name: The profile name to associate with the key
+        api_key: The API key to store
+
+    Returns:
+        True if successfully stored, False otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        sys_logger.warning("Keyring library not available, cannot store API key in keychain")
+        return False
+
+    try:
+        keyring.set_password(KEYRING_SERVICE_NAME, profile_name, api_key)
+        sys_logger.debug(f"Stored API key for profile '{profile_name}' in keychain")
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to store API key in keychain: {e}")
+        return False
+
+
+def get_api_key(profile_name: str) -> Optional[str]:
+    """
+    Retrieve an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to retrieve the key for
+
+    Returns:
+        The API key if found, None otherwise
+    """
+    if not KEYRING_AVAILABLE:
+        return None
+
+    try:
+        api_key = keyring.get_password(KEYRING_SERVICE_NAME, profile_name)
+        if api_key:
+            sys_logger.debug(f"Retrieved API key for profile '{profile_name}' from keychain")
+        return api_key
+    except Exception as e:
+        sys_logger.warning(f"Failed to retrieve API key from keychain: {e}")
+        return None
+
+
+def delete_api_key(profile_name: str) -> bool:
+    """
+    Delete an API key from the system keychain.
+
+    Args:
+        profile_name: The profile name to delete the key for
+
+    Returns:
+        True if successfully deleted or didn't exist, False on error
+    """
+    if not KEYRING_AVAILABLE:
+        return False
+
+    try:
+        keyring.delete_password(KEYRING_SERVICE_NAME, profile_name)
+        sys_logger.debug(f"Deleted API key for profile '{profile_name}' from keychain")
+        return True
+    except keyring.errors.PasswordDeleteError:
+        # Key didn't exist, which is fine
+        return True
+    except Exception as e:
+        sys_logger.warning(f"Failed to delete API key from keychain: {e}")
+        return False
cli/profile_cmds.py
CHANGED
@@ -50,10 +50,13 @@ def cmd_profile_use(cfg: Config, profile_name: str) -> None:
 @click.option(
     "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
 )
+@click.option(
+    "--use-keychain", is_flag=True, help="Store API key in system keychain instead of config file", default=False
+)
 @click.pass_obj
-def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
+def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool, use_keychain: bool) -> None:
     """Create a new profile."""
-    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
+    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key, use_keychain=use_keychain)

     cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
     profile_show(cfg)
hafnia/__init__.py
CHANGED
(hunk not captured in this diff view)
hafnia/dataset/dataset_helpers.py
CHANGED
@@ -38,12 +38,19 @@ def hash_from_bytes(data: bytes) -> str:

 def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
     pil_image = Image.fromarray(image)
+    path_image = save_pil_image_with_hash_name(pil_image, path_folder)
+    return path_image
+
+
+def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_skip: bool = True) -> Path:
     buffer = io.BytesIO()
-    pil_image.save(buffer, format="PNG")
+    image.save(buffer, format="PNG")
     hash_value = hash_from_bytes(buffer.getvalue())
     path_image = Path(path_folder) / relative_path_from_hash(hash=hash_value, suffix=".png")
+    if allow_skip and path_image.exists():
+        return path_image
     path_image.parent.mkdir(parents=True, exist_ok=True)
-    pil_image.save(path_image)
+    image.save(path_image)
     return path_image

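The refactor splits the hashing logic into `save_pil_image_with_hash_name` and adds an `allow_skip` early return, so an image whose content-hash path already exists is not re-encoded and re-written. A standalone sketch of the underlying idea (the sha256 digest and sharded layout here are illustrative, not the package's `relative_path_from_hash` scheme):

import hashlib
import io
from pathlib import Path

from PIL import Image


def save_png_by_content_hash(image: Image.Image, folder: Path, allow_skip: bool = True) -> Path:
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    digest = hashlib.sha256(buffer.getvalue()).hexdigest()
    path = folder / digest[:2] / f"{digest}.png"  # shard by hash prefix (illustrative layout)
    if allow_skip and path.exists():
        return path  # identical content was saved before, skip the write
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(buffer.getvalue())
    return path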
hafnia/dataset/dataset_names.py
CHANGED
@@ -49,7 +49,7 @@ class FieldName:

 class ColumnName:
     SAMPLE_INDEX: str = "sample_index"
-
+    FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
@@ -57,6 +57,7 @@ class ColumnName:
     ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
     TAGS: str = "tags"
     META: str = "meta"
+    DATASET_NAME: str = "dataset_name"


 class SplitName:
hafnia/dataset/dataset_recipe/dataset_recipe.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 from pydantic import (
     field_serializer,
@@ -12,7 +12,11 @@ from pydantic import (

 from hafnia import utils
 from hafnia.dataset.dataset_recipe import recipe_transforms
-from hafnia.dataset.dataset_recipe.recipe_types import
+from hafnia.dataset.dataset_recipe.recipe_types import (
+    RecipeCreation,
+    RecipeTransform,
+    Serializable,
+)
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.dataset.primitives.primitive import Primitive

@@ -41,6 +45,17 @@ class DatasetRecipe(Serializable):
         creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
         return DatasetRecipe(creation=creation)

+    @staticmethod
+    def from_name_public_dataset(
+        name: str, force_redownload: bool = False, n_samples: Optional[int] = None
+    ) -> DatasetRecipe:
+        creation = FromNamePublicDataset(
+            name=name,
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+        return DatasetRecipe(creation=creation)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
         creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
@@ -222,7 +237,7 @@ class DatasetRecipe(Serializable):
         """Serialize the dataset recipe to a dictionary."""
         return self.model_dump(mode="json")

-    def as_platform_recipe(self, recipe_name: Optional[str]) -> Dict:
+    def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
         """Uploads dataset recipe to the hafnia platform."""
         from cli.config import Config
         from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
@@ -235,6 +250,7 @@ class DatasetRecipe(Serializable):
             endpoint=endpoint_dataset,
             api_key=cfg.api_key,
             name=recipe_name,
+            overwrite=overwrite,
         )

         return recipe_dict
@@ -246,10 +262,17 @@ class DatasetRecipe(Serializable):
         return recipe

     def select_samples(
-        recipe: DatasetRecipe,
+        recipe: DatasetRecipe,
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> DatasetRecipe:
         operation = recipe_transforms.SelectSamples(
-            n_samples=n_samples,
+            n_samples=n_samples,
+            shuffle=shuffle,
+            seed=seed,
+            with_replacement=with_replacement,
         )
         recipe.append_operation(operation)
         return recipe
@@ -273,7 +296,7 @@ class DatasetRecipe(Serializable):

     def class_mapper(
         recipe: DatasetRecipe,
-        class_mapping: Dict[str, str],
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
         method: str = "strict",
         primitive: Optional[Type[Primitive]] = None,
         task_name: Optional[str] = None,
@@ -400,6 +423,22 @@ class FromName(RecipeCreation):
         return [self.name]


+class FromNamePublicDataset(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    n_samples: Optional[int] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name_public_dataset
+
+    def as_short_name(self) -> str:
+        return f"Torchvision('{self.name}')"
+
+    def get_dataset_names(self) -> List[str]:
+        return []
+
+
 class FromMerge(RecipeCreation):
     recipe0: DatasetRecipe
     recipe1: DatasetRecipe
@@ -414,7 +453,10 @@ class FromMerge(RecipeCreation):

     def get_dataset_names(self) -> List[str]:
         """Get the dataset names from the merged recipes."""
-        names = [
+        names = [
+            *self.recipe0.creation.get_dataset_names(),
+            *self.recipe1.creation.get_dataset_names(),
+        ]
         return names


@@ -439,33 +481,3 @@ class FromMerger(RecipeCreation):
         for recipe in self.recipes:
             names.extend(recipe.creation.get_dataset_names())
         return names
-
-
-def extract_dataset_names_from_json_dict(data: dict) -> list[str]:
-    """
-    Extract dataset names recursively from a JSON dictionary added with 'from_name'.
-
-    Even if the same functionality is achieved with `DatasetRecipe.get_dataset_names()`,
-    we want to keep this function in 'dipdatalib' to extract dataset names from json dictionaries
-    directly.
-    """
-    creation_field = data.get("creation")
-    if creation_field is None:
-        return []
-    if creation_field.get("__type__") == "FromName":
-        return [creation_field["name"]]
-    elif creation_field.get("__type__") == "FromMerge":
-        recipe_names = ["recipe0", "recipe1"]
-        dataset_name = []
-        for recipe_name in recipe_names:
-            recipe = creation_field.get(recipe_name)
-            if recipe is None:
-                continue
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    elif creation_field.get("__type__") == "FromMerger":
-        dataset_name = []
-        for recipe in creation_field.get("recipes", []):
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    return []
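With the new `FromNamePublicDataset` creation step, a recipe can start from a public (torchvision-backed) dataset and still be chained with the existing transform helpers. A usage sketch based on the signatures above (the dataset name is a placeholder; which names are accepted depends on `HafniaDataset.from_name_public_dataset`):

from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

# "mnist" is a placeholder dataset name used purely for illustration.
recipe = DatasetRecipe.from_name_public_dataset(name="mnist", n_samples=1_000)
recipe = recipe.select_samples(n_samples=100, shuffle=True, seed=42, with_replacement=False)
recipe_dict = recipe.model_dump(mode="json")  # same serialization the class uses internally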
hafnia/dataset/dataset_recipe/recipe_transforms.py
CHANGED
@@ -1,4 +1,6 @@
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import field_validator

 from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
 from hafnia.dataset.hafnia_dataset import HafniaDataset
@@ -52,11 +54,25 @@ class DefineSampleSetBySize(RecipeTransform):


 class ClassMapper(RecipeTransform):
-    class_mapping: Dict[str, str]
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]]
     method: str = "strict"
     primitive: Optional[Type[Primitive]] = None
     task_name: Optional[str] = None

+    @field_validator("class_mapping", mode="after")
+    @classmethod
+    def serialize_class_mapping(cls, value: Union[Dict[str, str], List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
+        # Converts the dictionary class mapping to a list of tuples
+        # e.g. {"old_class": "new_class", } --> [("old_class", "new_class")]
+        # The reason is that storing class mappings as a dictionary does not preserve order of json fields
+        # when stored in a database as a jsonb field (postgres).
+        # Preserving order of class mapping fields is important as it defines the indices of the classes.
+        # So to ensure that class indices are maintained, we preserve order of json fields, by converting the
+        # dictionary to a list of tuples.
+        if isinstance(value, dict):
+            value = list(value.items())
+        return value
+
     @staticmethod
     def get_function() -> Callable[..., "HafniaDataset"]:
         return HafniaDataset.class_mapper
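The validator keeps class order stable by normalizing a dict mapping into a list of tuples at validation time, so the serialized recipe survives round-trips through order-insensitive JSON stores. A small sketch (assuming `RecipeTransform` adds no further required fields):

from hafnia.dataset.dataset_recipe.recipe_transforms import ClassMapper

mapper = ClassMapper(class_mapping={"car": "vehicle", "truck": "vehicle", "person": "person"})
print(mapper.class_mapping)
# [('car', 'vehicle'), ('truck', 'vehicle'), ('person', 'person')] - order preserved as a list of tuples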
hafnia/dataset/dataset_upload_helper.py
CHANGED
@@ -4,7 +4,7 @@ import base64
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import boto3
 import polars as pl
@@ -52,6 +52,7 @@ class DbDataset(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
     license_citation: Optional[str] = None
     version: Optional[str] = None
     s3_bucket_name: Optional[str] = None
+    dataset_format_version: Optional[str] = None
     annotation_date: Optional[datetime] = None
     annotation_project_id: Optional[str] = None
     annotation_dataset_id: Optional[str] = None
@@ -186,9 +187,58 @@ class EntityTypeChoices(str, Enum):  # Should match `EntityTypeChoices` in `dipd
     EVENT = "EVENT"


+class Annotations(BaseModel):
+    """
+    Used in 'DatasetImageMetadata' for visualizing image annotations
+    in gallery images on the dataset detail page.
+    """
+
+    objects: Optional[List[Bbox]] = None
+    classifications: Optional[List[Classification]] = None
+    polygons: Optional[List[Polygon]] = None
+    bitmasks: Optional[List[Bitmask]] = None
+
+
+class DatasetImageMetadata(BaseModel):
+    """
+    Metadata for gallery images on the dataset detail page on portal.
+    """
+
+    annotations: Optional[Annotations] = None
+    meta: Optional[Dict[str, Any]] = None
+
+    @classmethod
+    def from_sample(cls, sample: Sample) -> "DatasetImageMetadata":
+        sample = sample.model_copy(deep=True)
+        sample.file_path = "/".join(Path(sample.file_path).parts[-3:])
+        metadata = {}
+        metadata_field_names = [
+            ColumnName.FILE_PATH,
+            ColumnName.HEIGHT,
+            ColumnName.WIDTH,
+            ColumnName.SPLIT,
+        ]
+        for field_name in metadata_field_names:
+            if hasattr(sample, field_name) and getattr(sample, field_name) is not None:
+                metadata[field_name] = getattr(sample, field_name)
+
+        obj = DatasetImageMetadata(
+            annotations=Annotations(
+                objects=sample.objects,
+                classifications=sample.classifications,
+                polygons=sample.polygons,
+                bitmasks=sample.bitmasks,
+            ),
+            meta=metadata,
+        )
+
+        return obj
+
+
 class DatasetImage(Attribution, validate_assignment=True):  # type: ignore[call-arg]
     img: str  # Base64-encoded image string
     order: Optional[int] = None
+    metadata: Optional[DatasetImageMetadata] = None

     @field_validator("img", mode="before")
     def validate_image_path(cls, v: Union[str, Path]) -> str:
@@ -254,7 +304,7 @@ def upload_dataset_details(cfg: Config, data: str, dataset_name: str) -> dict:
     import_endpoint = f"{dataset_endpoint}/{dataset_id}/import"
     headers = {"Authorization": cfg.api_key}

-    user_logger.info("
+    user_logger.info("Exporting dataset details to platform. This may take up to 30 seconds...")
     response = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
     return response  # type: ignore[return-value]

@@ -569,7 +619,9 @@ def dataset_info_from_dataset(
         s3_bucket_name=bucket_sample,
         dataset_variants=dataset_variants,
         split_annotations_reports=dataset_reports,
-
+        latest_update=dataset.info.updated_at,
+        dataset_format_version=dataset.info.format_version,
+        license_citation=dataset.info.reference_bibtex,
         data_captured_start=dataset_meta_info.get("data_captured_start", None),
         data_captured_end=dataset_meta_info.get("data_captured_end", None),
         data_received_start=dataset_meta_info.get("data_received_start", None),
@@ -594,7 +646,7 @@ def create_gallery_images(
     path_gallery_images.mkdir(parents=True, exist_ok=True)
     COL_IMAGE_NAME = "image_name"
     samples = dataset.samples.with_columns(
-        dataset.samples[ColumnName.
+        dataset.samples[ColumnName.FILE_PATH].str.split("/").list.last().alias(COL_IMAGE_NAME)
     )
     gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))

@@ -604,6 +656,9 @@ def create_gallery_images(
     gallery_images = []
     for gallery_sample in gallery_samples.iter_rows(named=True):
         sample = Sample(**gallery_sample)
+
+        metadata = DatasetImageMetadata.from_sample(sample=sample)
+        sample.classifications = None  # To not draw classifications in gallery images
         image = sample.draw_annotations()

         path_gallery_image = path_gallery_images / gallery_sample[COL_IMAGE_NAME]
@@ -611,6 +666,7 @@ def create_gallery_images(

         dataset_image_dict = {
             "img": path_gallery_image,
+            "metadata": metadata,
         }
         if sample.attribution is not None:
             sample.attribution.changes = "Annotations have been visualized"
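`DatasetImageMetadata.from_sample` trims the sample's absolute `file_path` to its last three path components before it is embedded in the gallery metadata, presumably so the full local directory layout is not shipped to the portal. The trimming expression in isolation (the path is illustrative):

from pathlib import Path

file_path = "/home/user/.cache/hafnia/some-dataset/data/abc123.png"  # illustrative path
short = "/".join(Path(file_path).parts[-3:])
print(short)  # some-dataset/data/abc123.png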
hafnia/dataset/format_conversions/image_classification_from_directory.py
ADDED
@@ -0,0 +1,106 @@
+import shutil
+from pathlib import Path
+from typing import List, Optional
+
+import more_itertools
+import polars as pl
+from PIL import Image
+from rich.progress import track
+
+from hafnia.dataset.dataset_names import ColumnName, FieldName
+from hafnia.dataset.hafnia_dataset import DatasetInfo, HafniaDataset, Sample, TaskInfo
+from hafnia.dataset.primitives import Classification
+from hafnia.utils import is_image_file
+
+
+def import_image_classification_directory_tree(
+    path_folder: Path,
+    split: str,
+    n_samples: Optional[int] = None,
+) -> HafniaDataset:
+    class_folder_paths = [path for path in path_folder.iterdir() if path.is_dir()]
+    class_names = sorted([folder.name for folder in class_folder_paths])  # Sort for determinism
+
+    # Gather all image paths per class
+    path_images_per_class: List[List[Path]] = []
+    for path_class_folder in class_folder_paths:
+        per_class_images = []
+        for path_image in list(path_class_folder.rglob("*.*")):
+            if is_image_file(path_image):
+                per_class_images.append(path_image)
+        path_images_per_class.append(sorted(per_class_images))
+
+    # Interleave to ensure classes are balanced in the output dataset for n_samples < total
+    path_images = list(more_itertools.interleave_longest(*path_images_per_class))
+
+    if n_samples is not None:
+        path_images = path_images[:n_samples]
+
+    samples = []
+    for path_image_org in track(path_images, description="Convert 'image classification' dataset to Hafnia Dataset"):
+        class_name = path_image_org.parent.name
+
+        read_image = Image.open(path_image_org)
+        width, height = read_image.size
+
+        classifications = [Classification(class_name=class_name, class_idx=class_names.index(class_name))]
+        sample = Sample(
+            file_path=str(path_image_org.absolute()),
+            width=width,
+            height=height,
+            split=split,
+            classifications=classifications,
+        )
+        samples.append(sample)
+
+    dataset_info = DatasetInfo(
+        dataset_name="ImageClassificationFromDirectoryTree",
+        tasks=[TaskInfo(primitive=Classification, class_names=class_names)],
+    )
+
+    hafnia_dataset = HafniaDataset.from_samples_list(samples_list=samples, info=dataset_info)
+    return hafnia_dataset
+
+
+def export_image_classification_directory_tree(
+    dataset: HafniaDataset,
+    path_output: Path,
+    task_name: Optional[str] = None,
+    clean_folder: bool = False,
+) -> Path:
+    task = dataset.info.get_task_by_task_name_and_primitive(task_name=task_name, primitive=Classification)
+
+    samples = dataset.samples.with_columns(
+        pl.col(task.primitive.column_name())
+        .list.filter(pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+        .alias(task.primitive.column_name())
+    )
+
+    classification_counts = samples[task.primitive.column_name()].list.len()
+    has_no_classification_samples = (classification_counts == 0).sum()
+    if has_no_classification_samples > 0:
+        raise ValueError(f"Some samples do not have a classification for task '{task.name}'.")
+
+    has_multi_classification_samples = (classification_counts > 1).sum()
+    if has_multi_classification_samples > 0:
+        raise ValueError(f"Some samples have multiple classifications for task '{task.name}'.")
+
+    if clean_folder:
+        shutil.rmtree(path_output, ignore_errors=True)
+    path_output.mkdir(parents=True, exist_ok=True)
+
+    description = "Export Hafnia Dataset to directory tree"
+    for sample_dict in track(samples.iter_rows(named=True), total=len(samples), description=description):
+        classifications = sample_dict[task.primitive.column_name()]
+        if len(classifications) != 1:
+            raise ValueError("Each sample should have exactly one classification.")
+        classification = classifications[0]
+        class_name = classification[FieldName.CLASS_NAME].replace("/", "_")  # Avoid issues with subfolders
+        path_class_folder = path_output / class_name
+        path_class_folder.mkdir(parents=True, exist_ok=True)
+
+        path_image_org = Path(sample_dict[ColumnName.FILE_PATH])
+        path_image_new = path_class_folder / path_image_org.name
+        shutil.copy2(path_image_org, path_image_new)
+
+    return path_output