hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +2 -2
- cli/config.py +17 -4
- cli/dataset_cmds.py +60 -0
- cli/runc_cmds.py +1 -1
- hafnia/data/__init__.py +2 -2
- hafnia/data/factory.py +12 -56
- hafnia/dataset/dataset_helpers.py +91 -0
- hafnia/dataset/dataset_names.py +72 -0
- hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
- hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
- hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
- hafnia/dataset/dataset_upload_helper.py +468 -0
- hafnia/dataset/hafnia_dataset.py +624 -0
- hafnia/dataset/operations/dataset_stats.py +15 -0
- hafnia/dataset/operations/dataset_transformations.py +82 -0
- hafnia/dataset/operations/table_transformations.py +183 -0
- hafnia/dataset/primitives/__init__.py +16 -0
- hafnia/dataset/primitives/bbox.py +137 -0
- hafnia/dataset/primitives/bitmask.py +182 -0
- hafnia/dataset/primitives/classification.py +56 -0
- hafnia/dataset/primitives/point.py +25 -0
- hafnia/dataset/primitives/polygon.py +100 -0
- hafnia/dataset/primitives/primitive.py +44 -0
- hafnia/dataset/primitives/segmentation.py +51 -0
- hafnia/dataset/primitives/utils.py +51 -0
- hafnia/experiment/hafnia_logger.py +7 -7
- hafnia/helper_testing.py +108 -0
- hafnia/http.py +5 -3
- hafnia/platform/__init__.py +2 -2
- hafnia/platform/datasets.py +197 -0
- hafnia/platform/download.py +85 -23
- hafnia/torch_helpers.py +180 -95
- hafnia/utils.py +21 -2
- hafnia/visualizations/colors.py +267 -0
- hafnia/visualizations/image_visualizations.py +202 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
- hafnia-0.2.1.dist-info/RECORD +50 -0
- cli/data_cmds.py +0 -53
- hafnia-0.1.27.dist-info/RECORD +0 -27
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/primitives/polygon.py
ADDED
@@ -0,0 +1,100 @@
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+import cv2
+import numpy as np
+
+from hafnia.dataset.primitives.bitmask import Bitmask
+from hafnia.dataset.primitives.point import Point
+from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.dataset.primitives.utils import class_color_by_name, get_class_name
+
+
+class Polygon(Primitive):
+    # Names should match names in FieldName
+    points: List[Point]
+    class_name: Optional[str] = None  # This should match the string in 'FieldName.CLASS_NAME'
+    class_idx: Optional[int] = None  # This should match the string in 'FieldName.CLASS_IDX'
+    object_id: Optional[str] = None  # This should match the string in 'FieldName.OBJECT_ID'
+    confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
+    ground_truth: bool = True  # Whether this is ground truth or a prediction
+
+    task_name: str = ""  # Task name to support multiple Polygon tasks in the same dataset. "" defaults to "polygon"
+    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
+
+    @staticmethod
+    def from_list_of_points(
+        points: Sequence[Sequence[float]],
+        class_name: Optional[str] = None,
+        class_idx: Optional[int] = None,
+        object_id: Optional[str] = None,
+    ) -> "Polygon":
+        list_points = [Point(x=point[0], y=point[1]) for point in points]
+        return Polygon(points=list_points, class_name=class_name, class_idx=class_idx, object_id=object_id)
+
+    @staticmethod
+    def default_task_name() -> str:
+        return "polygon"
+
+    @staticmethod
+    def column_name() -> str:
+        return "polygons"
+
+    def calculate_area(self) -> float:
+        raise NotImplementedError()
+
+    def to_pixel_coordinates(
+        self, image_shape: Tuple[int, int], as_int: bool = True, clip_values: bool = True
+    ) -> List[Tuple]:
+        points = [
+            point.to_pixel_coordinates(image_shape=image_shape, as_int=as_int, clip_values=clip_values)
+            for point in self.points
+        ]
+        return points
+
+    def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+        if not inplace:
+            image = image.copy()
+        points = np.array(self.to_pixel_coordinates(image_shape=image.shape[:2]))
+
+        bottom_left_idx = np.lexsort((-points[:, 1], points[:, 0]))[0]
+        bottom_left_np = points[bottom_left_idx, :]
+        margin = 5
+        bottom_left = (bottom_left_np[0] + margin, bottom_left_np[1] - margin)
+
+        class_name = self.get_class_name()
+        color = class_color_by_name(class_name)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
+        cv2.putText(
+            img=image, text=class_name, org=bottom_left, fontFace=font, fontScale=0.75, color=color, thickness=2
+        )
+        return image
+
+    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+        if not inplace:
+            image = image.copy()
+        points = np.array(self.to_pixel_coordinates(image_shape=image.shape[:2]))
+        mask = np.zeros(image.shape[:2], dtype=np.uint8)
+        mask = cv2.fillPoly(mask, [points], color=255).astype(bool)
+        bitmask = Bitmask.from_mask(mask=mask, top=0, left=0).squeeze_mask()
+        image = bitmask.anonymize_by_blurring(image=image, inplace=inplace, max_resolution=max_resolution)
+
+        return image
+
+    def mask(
+        self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
+    ) -> np.ndarray:
+        if not inplace:
+            image = image.copy()
+        points = self.to_pixel_coordinates(image_shape=image.shape[:2])
+
+        if color is None:
+            mask = np.zeros_like(image[:, :, 0])
+            bitmask = cv2.fillPoly(mask, pts=[np.array(points)], color=255).astype(bool)  # type: ignore[assignment]
+            color = tuple(int(value) for value in np.mean(image[bitmask], axis=0))  # type: ignore[assignment]
+
+        cv2.fillPoly(image, [np.array(points)], color=color)
+        return image
+
+    def get_class_name(self) -> str:
+        return get_class_name(self.class_name, self.class_idx)
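
Not part of the diff: a minimal usage sketch of the new Polygon primitive. It assumes point coordinates are stored as relative (0-1) values, which is what `to_pixel_coordinates()` suggests, and uses a dummy image and placeholder class name.

```python
import numpy as np

from hafnia.dataset.primitives.polygon import Polygon

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy image for illustration

# Build a polygon from (x, y) pairs; class_name drives the label text and color.
polygon = Polygon.from_list_of_points(
    points=[(0.2, 0.2), (0.8, 0.2), (0.5, 0.7)],
    class_name="person",
)

annotated = polygon.draw(image)                  # outline + class label, returns a copy
masked = polygon.mask(image, color=(0, 0, 255))  # fill the polygon region with a color
blurred = polygon.anonymize_by_blurring(image)   # blur the region via a Bitmask
```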

hafnia/dataset/primitives/primitive.py
ADDED
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+from pydantic import BaseModel
+
+
+class Primitive(BaseModel, metaclass=ABCMeta):
+    def model_post_init(self, context) -> None:
+        if self.task_name == "":  # type: ignore[has-type] # Hack because 'task_name' doesn't exist in base-class yet.
+            self.task_name = self.default_task_name()
+
+    @staticmethod
+    @abstractmethod
+    def default_task_name() -> str:
+        # E.g. "return bboxes" for Bbox
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def column_name() -> str:
+        """
+        Name of field used in hugging face datasets for storing annotations
+        E.g. "objects" for Bbox.
+        """
+        pass
+
+    @abstractmethod
+    def calculate_area(self) -> float:
+        # Calculate the area of the primitive
+        pass
+
+    @abstractmethod
+    def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def mask(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+        pass
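
Not part of the diff: an illustrative subclass sketch showing the contract Primitive imposes on the concrete primitives in this release. "Keypoint" is a made-up name, not a class in the package.

```python
import numpy as np

from hafnia.dataset.primitives.primitive import Primitive


class Keypoint(Primitive):  # hypothetical primitive, for illustration only
    x: float
    y: float
    task_name: str = ""  # model_post_init() replaces "" with default_task_name()

    @staticmethod
    def default_task_name() -> str:
        return "keypoint"

    @staticmethod
    def column_name() -> str:
        return "keypoints"

    def calculate_area(self) -> float:
        return 0.0

    def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
        return image

    def mask(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
        return image

    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
        return image


kp = Keypoint(x=0.5, y=0.5)
print(kp.task_name)  # -> "keypoint", filled in by model_post_init()
```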

hafnia/dataset/primitives/segmentation.py
ADDED
@@ -0,0 +1,51 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.dataset.primitives.utils import get_class_name
+from hafnia.visualizations.colors import get_n_colors
+
+
+class Segmentation(Primitive):
+    # mask: np.ndarray
+    class_names: Optional[List[str]] = None  # This should match the string in 'FieldName.CLASS_NAME'
+    ground_truth: bool = True  # Whether this is ground truth or a prediction
+
+    # confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification
+    task_name: str = (
+        ""  # Task name to support multiple Segmentation tasks in the same dataset. "" defaults to "segmentation"
+    )
+    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
+
+    @staticmethod
+    def default_task_name() -> str:
+        return "segmentation"
+
+    @staticmethod
+    def column_name() -> str:
+        return "segmentation"
+
+    def calculate_area(self) -> float:
+        raise NotImplementedError()
+
+    def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+        if not inplace:
+            image = image.copy()
+
+        color_mapping = np.asarray(get_n_colors(len(self.class_names)), dtype=np.uint8)  # type: ignore[arg-type]
+        label_image = color_mapping[self.mask]
+        blended = cv2.addWeighted(image, 0.5, label_image, 0.5, 0)
+        return blended
+
+    def mask(
+        self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
+    ) -> np.ndarray:
+        return image
+
+    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+        return image
+
+    def get_class_name(self) -> str:
+        return get_class_name(self.class_name, self.class_idx)

hafnia/dataset/primitives/utils.py
ADDED
@@ -0,0 +1,51 @@
+import hashlib
+from typing import Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+
+def text_org_from_left_bottom_to_centered(xy_org: tuple, text: str, font, font_scale: float, thickness: int) -> tuple:
+    xy_text_size = cv2.getTextSize(text, fontFace=font, fontScale=font_scale, thickness=thickness)[0]
+    xy_text_size_half = np.array(xy_text_size) / 2
+    xy_centered_np = xy_org + xy_text_size_half * np.array([-1, 1])
+    xy_centered = tuple(int(value) for value in xy_centered_np)
+    return xy_centered
+
+
+def round_int_clip_value(value: Union[int, float], max_value: int) -> int:
+    return clip(value=int(round(value)), v_min=0, v_max=max_value)  # noqa: RUF046
+
+
+def class_color_by_name(name: str) -> Tuple[int, int, int]:
+    # Create a hash of the class name
+    hash_object = hashlib.md5(name.encode())
+    # Use the hash to generate a color
+    hash_digest = hash_object.hexdigest()
+    color = (int(hash_digest[0:2], 16), int(hash_digest[2:4], 16), int(hash_digest[4:6], 16))
+    return color
+
+
+# Define an abstract base class
+def clip(value, v_min, v_max):
+    return min(max(v_min, value), v_max)
+
+
+def get_class_name(class_name: Optional[str], class_idx: Optional[int]) -> str:
+    if class_name is not None:
+        return class_name
+    if class_idx is not None:
+        return f"IDX:{class_idx}"
+    return "NoName"
+
+
+def anonymize_by_resizing(blur_region: np.ndarray, max_resolution: int = 20) -> np.ndarray:
+    """
+    Removes high-frequency details from a region of an image by resizing it down and then back up.
+    """
+    original_shape = blur_region.shape[:2]
+    resize_factor = max(original_shape) / max_resolution
+    new_size = (int(original_shape[0] / resize_factor), int(original_shape[1] / resize_factor))
+    blur_region_downsized = cv2.resize(blur_region, new_size[::-1], interpolation=cv2.INTER_LINEAR)
+    blur_region_upsized = cv2.resize(blur_region_downsized, original_shape[::-1], interpolation=cv2.INTER_LINEAR)
+    return blur_region_upsized
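
Not part of the diff: a short sketch of the helpers above. class_color_by_name() hashes the class name with MD5, so the same name always maps to the same color tuple.

```python
from hafnia.dataset.primitives.utils import class_color_by_name, clip, get_class_name

assert class_color_by_name("person") == class_color_by_name("person")  # deterministic
print(class_color_by_name("person"))      # a tuple of three 0-255 ints
print(clip(300, v_min=0, v_max=255))      # -> 255
print(get_class_name(None, class_idx=7))  # -> "IDX:7"
print(get_class_name(None, None))         # -> "NoName"
```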

hafnia/experiment/hafnia_logger.py
CHANGED
@@ -9,12 +9,12 @@ from typing import Dict, Optional, Union

 import pyarrow as pa
 import pyarrow.parquet as pq
-from datasets import DatasetDict
 from pydantic import BaseModel, field_validator

 from hafnia.data.factory import load_dataset
+from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.log import sys_logger, user_logger
-from hafnia.utils import
+from hafnia.utils import is_hafnia_cloud_job, now_as_str


 class EntityType(Enum):
@@ -92,7 +92,7 @@ class HafniaLogger:
         self.schema = Entity.create_schema()
         self.log_environment()

-    def load_dataset(self, dataset_name: str) ->
+    def load_dataset(self, dataset_name: str) -> HafniaDataset:
         """
         Load a dataset from the specified path.
         """
@@ -101,7 +101,7 @@ class HafniaLogger:

     def path_local_experiment(self) -> Path:
         """Get the path for local experiment."""
-        if
+        if is_hafnia_cloud_job():
             raise RuntimeError("Cannot access local experiment path in remote job.")
         return self._local_experiment_path

@@ -110,7 +110,7 @@ class HafniaLogger:
         if "MDI_CHECKPOINT_DIR" in os.environ:
             return Path(os.environ["MDI_CHECKPOINT_DIR"])

-        if
+        if is_hafnia_cloud_job():
             return Path("/opt/ml/checkpoints")
         return self.path_local_experiment() / "checkpoints"

@@ -119,7 +119,7 @@ class HafniaLogger:
         if "MDI_ARTIFACT_DIR" in os.environ:
             return Path(os.environ["MDI_ARTIFACT_DIR"])

-        if
+        if is_hafnia_cloud_job():
             return Path("/opt/ml/output/data")

         return self.path_local_experiment() / "data"
@@ -129,7 +129,7 @@ class HafniaLogger:
         if "MDI_MODEL_DIR" in os.environ:
             return Path(os.environ["MDI_MODEL_DIR"])

-        if
+        if is_hafnia_cloud_job():
             return Path("/opt/ml/model")

         return self.path_local_experiment() / "model"
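
Not part of the diff: the path hunks above all switch their guard to is_hafnia_cloud_job(). A sketch of the resulting resolution order follows; the enclosing method names are not visible in these hunks, so resolve_checkpoint_dir below is a hypothetical stand-in mirroring the checkpoint hunk.

```python
import os
from pathlib import Path

from hafnia.utils import is_hafnia_cloud_job


def resolve_checkpoint_dir(local_experiment_path: Path) -> Path:
    """Illustrative helper, not part of the package."""
    if "MDI_CHECKPOINT_DIR" in os.environ:            # explicit override wins
        return Path(os.environ["MDI_CHECKPOINT_DIR"])
    if is_hafnia_cloud_job():                         # fixed path used in cloud jobs
        return Path("/opt/ml/checkpoints")
    return local_experiment_path / "checkpoints"      # local fallback
```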

hafnia/helper_testing.py
ADDED
@@ -0,0 +1,108 @@
+from inspect import getmembers, isfunction, signature
+from pathlib import Path
+from types import FunctionType
+from typing import Any, Callable, Dict, Union, get_origin
+
+from hafnia import utils
+from hafnia.dataset.dataset_names import FILENAME_ANNOTATIONS_JSONL, DatasetVariant
+from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+MICRO_DATASETS = {
+    "tiny-dataset": utils.PATH_DATASETS / "tiny-dataset",
+    "coco-2017": utils.PATH_DATASETS / "coco-2017",
+}
+
+
+def get_path_workspace() -> Path:
+    return Path(__file__).parents[2]
+
+
+def get_path_expected_images() -> Path:
+    return get_path_workspace() / "tests" / "data" / "expected_images"
+
+
+def get_path_test_data() -> Path:
+    return get_path_workspace() / "tests" / "data"
+
+
+def get_path_micro_hafnia_dataset_no_check() -> Path:
+    return get_path_test_data() / "micro_test_datasets"
+
+
+def get_path_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Path:
+    import pytest
+
+    if dataset_name not in MICRO_DATASETS:
+        raise ValueError(f"Dataset name '{dataset_name}' is not recognized. Available options: {list(MICRO_DATASETS)}")
+    path_dataset = MICRO_DATASETS[dataset_name]
+
+    path_test_dataset = get_path_micro_hafnia_dataset_no_check() / dataset_name
+    path_test_dataset_annotations = path_test_dataset / FILENAME_ANNOTATIONS_JSONL
+    if path_test_dataset_annotations.exists() and not force_update:
+        return path_test_dataset
+
+    hafnia_dataset = HafniaDataset.from_path(path_dataset / DatasetVariant.SAMPLE.value)
+    hafnia_dataset = hafnia_dataset.select_samples(n_samples=3, seed=42)
+    hafnia_dataset.write(path_test_dataset)
+
+    if force_update:
+        pytest.fail(
+            "Sample image and metadata have been updated using 'force_update=True'. Set 'force_update=False' and rerun the test."
+        )
+    pytest.fail("Missing test sample image. Please rerun the test.")
+    return path_test_dataset
+
+
+def get_sample_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Sample:
+    micro_dataset = get_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+    sample_dict = micro_dataset[0]
+    sample = Sample(**sample_dict)
+    return sample
+
+
+def get_micro_hafnia_dataset(dataset_name: str, force_update: bool = False) -> HafniaDataset:
+    path_dataset = get_path_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+    hafnia_dataset = HafniaDataset.from_path(path_dataset)
+    return hafnia_dataset
+
+
+def is_hafnia_configured() -> bool:
+    """
+    Check if Hafnia is configured by verifying if the API key is set.
+    """
+    from cli.config import Config
+
+    return Config().is_configured()
+
+
+def is_typing_type(annotation: Any) -> bool:
+    return get_origin(annotation) is not None
+
+
+def annotation_as_string(annotation: Union[type, str]) -> str:
+    """Convert type annotation to string."""
+    if isinstance(annotation, str):
+        return annotation.replace("'", "")
+    if is_typing_type(annotation):  # Is using typing types like List, Dict, etc.
+        return str(annotation).replace("typing.", "")
+    if hasattr(annotation, "__name__"):
+        return annotation.__name__
+    return str(annotation)
+
+
+def get_hafnia_functions_from_module(python_module) -> Dict[str, FunctionType]:
+    def dataset_is_first_arg(func: Callable) -> bool:
+        """
+        Check if the function has 'HafniaDataset' as the first parameter.
+        """
+        func_signature = signature(func)
+        params = func_signature.parameters
+        if len(params) == 0:
+            return False
+        first_argument_type = list(params.values())[0]
+
+        annotation_as_str = annotation_as_string(first_argument_type.annotation)
+        return annotation_as_str == "HafniaDataset"
+
+    functions = {func[0]: func[1] for func in getmembers(python_module, isfunction) if dataset_is_first_arg(func[1])}
+    return functions
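
Not part of the diff: a hypothetical pytest-style use of the new test helpers. It assumes the micro dataset fixtures (or the underlying sample datasets) are available locally and that HafniaDataset exposes its sample table as .samples, as used elsewhere in this diff.

```python
from hafnia.helper_testing import get_micro_hafnia_dataset, get_sample_micro_hafnia_dataset


def test_micro_dataset_loads():
    dataset = get_micro_hafnia_dataset("tiny-dataset")  # key from MICRO_DATASETS
    sample = get_sample_micro_hafnia_dataset("tiny-dataset")
    assert len(dataset.samples) > 0
    assert sample is not None
```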

hafnia/http.py
CHANGED
@@ -31,7 +31,7 @@ def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
         http.clear()


-def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart: bool = False) -> Dict:
+def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], multipart: bool = False) -> Dict:
     """Posts data to backend endpoint.

     Args:
@@ -64,9 +64,11 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart
         with open(data, "rb") as f:
             body = f.read()
         response = http.request("POST", endpoint, body=body, headers=headers)
-    elif isinstance(data, dict):
+    elif isinstance(data, (str, dict)):
+        if isinstance(data, dict):
+            data = json.dumps(data)
         headers["Content-Type"] = "application/json"
-        response = http.request("POST", endpoint, body=
+        response = http.request("POST", endpoint, body=data, headers=headers)
     elif isinstance(data, bytes):
         response = http.request("POST", endpoint, body=data, headers=headers)
     else:
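
Not part of the diff: what the widened post() signature allows. A dict is still JSON-serialized internally; a pre-serialized JSON string is now sent as-is, in both cases with Content-Type: application/json. Endpoint and header values below are placeholders.

```python
import json

from hafnia.http import post

endpoint = "https://platform.example/api/resource"  # placeholder URL
headers = {"Authorization": "<api-key>"}            # placeholder key

post(endpoint, headers=headers, data={"name": "my-dataset"})              # dict, as before
post(endpoint, headers=headers, data=json.dumps({"name": "my-dataset"}))  # JSON string, new in 0.2.1
```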

hafnia/platform/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 from hafnia.platform.download import (
     download_resource,
     download_single_object,
-
+    get_resource_credentials,
 )
 from hafnia.platform.experiment import (
     create_experiment,
@@ -17,5 +17,5 @@ __all__ = [
     "create_experiment",
     "download_resource",
     "download_single_object",
-    "
+    "get_resource_credentials",
 ]

hafnia/platform/datasets.py
ADDED
@@ -0,0 +1,197 @@
+import os
+import shutil
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import rich
+from tqdm import tqdm
+
+from cli.config import Config
+from hafnia import utils
+from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
+from hafnia.dataset.dataset_recipe.dataset_recipe import (
+    DatasetRecipe,
+    get_dataset_path_from_recipe,
+)
+from hafnia.dataset.hafnia_dataset import HafniaDataset
+from hafnia.http import fetch
+from hafnia.log import sys_logger, user_logger
+from hafnia.platform import get_dataset_id
+from hafnia.platform.download import get_resource_credentials
+from hafnia.utils import timed
+
+
+@timed("Fetching dataset list.")
+def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+    """List available datasets on the Hafnia platform."""
+    cfg = cfg or Config()
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    header = {"Authorization": cfg.api_key}
+    datasets: List[Dict[str, str]] = fetch(endpoint_dataset, headers=header)  # type: ignore
+    if not datasets:
+        raise ValueError("No datasets found on the Hafnia platform.")
+
+    return datasets
+
+
+def download_or_get_dataset_path(
+    dataset_name: str,
+    cfg: Optional[Config] = None,
+    path_datasets_folder: Optional[str] = None,
+    force_redownload: bool = False,
+    download_files: bool = True,
+) -> Path:
+    """Download or get the path of the dataset."""
+    recipe_explicit = DatasetRecipe.from_implicit_form(dataset_name)
+    path_dataset = get_dataset_path_from_recipe(recipe_explicit, path_datasets=path_datasets_folder)
+
+    is_dataset_valid = HafniaDataset.check_dataset_path(path_dataset, raise_error=False)
+    if is_dataset_valid and not force_redownload:
+        user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
+        return path_dataset
+
+    cfg = cfg or Config()
+    api_key = cfg.api_key
+
+    shutil.rmtree(path_dataset, ignore_errors=True)
+
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    dataset_id = get_dataset_id(dataset_name=dataset_name, endpoint=endpoint_dataset, api_key=api_key)
+    if dataset_id is None:
+        sys_logger.error(f"Dataset '{dataset_name}' not found on the Hafnia platform.")
+    access_dataset_endpoint = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
+
+    download_dataset_from_access_endpoint(
+        endpoint=access_dataset_endpoint,
+        api_key=api_key,
+        path_dataset=path_dataset,
+        download_files=download_files,
+    )
+    return path_dataset
+
+
+def download_dataset_from_access_endpoint(
+    endpoint: str,
+    api_key: str,
+    path_dataset: Path,
+    download_files: bool = True,
+) -> None:
+    resource_credentials = get_resource_credentials(endpoint, api_key)
+
+    local_dataset_paths = [str(path_dataset / filename) for filename in DATASET_FILENAMES_REQUIRED]
+    s3_uri = resource_credentials.s3_uri()
+    s3_dataset_files = [f"{s3_uri}/{filename}" for filename in DATASET_FILENAMES_REQUIRED]
+
+    envs = resource_credentials.aws_credentials()
+    fast_copy_files_s3(
+        src_paths=s3_dataset_files,
+        dst_paths=local_dataset_paths,
+        append_envs=envs,
+        description="Downloading annotations",
+    )
+
+    if not download_files:
+        return
+
+    dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
+    fast_copy_files_s3(
+        src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
+        dst_paths=dataset.samples[ColumnName.FILE_NAME].to_list(),
+        append_envs=envs,
+        description="Downloading images",
+    )
+
+
+def fast_copy_files_s3(
+    src_paths: List[str],
+    dst_paths: List[str],
+    append_envs: Optional[Dict[str, str]] = None,
+    description: str = "Copying files",
+) -> List[str]:
+    if len(src_paths) != len(dst_paths):
+        raise ValueError("Source and destination paths must have the same length.")
+
+    cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
+    lines = execute_s5cmd_commands(cmds, append_envs=append_envs, description=description)
+    return lines
+
+
+def execute_s5cmd_commands(
+    commands: List[str],
+    append_envs: Optional[Dict[str, str]] = None,
+    description: str = "Executing s5cmd commands",
+) -> List[str]:
+    append_envs = append_envs or {}
+    with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_file:
+        tmp_file_path = Path(tmp_file.name)
+        tmp_file_path.write_text("\n".join(commands))
+        run_cmds = [
+            "s5cmd",
+            "run",
+            str(tmp_file_path),
+        ]
+        envs = os.environ.copy()
+        envs.update(append_envs)
+
+        process = subprocess.Popen(
+            run_cmds,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            universal_newlines=True,
+            env=envs,
+        )
+
+        error_lines = []
+        lines = []
+        for line in tqdm(process.stdout, total=len(commands), desc=description):
+            if "ERROR" in line or "error" in line:
+                error_lines.append(line.strip())
+            lines.append(line.strip())
+
+        if len(error_lines) > 0:
+            show_n_lines = min(5, len(error_lines))
+            str_error_lines = "\n".join(error_lines[:show_n_lines])
+            user_logger.error(
+                f"Detected {len(error_lines)} errors occurred while executing a total of {len(commands)} "
+                f" commands with s5cmd. The first {show_n_lines} is printed below:\n{str_error_lines}"
+            )
+            raise RuntimeError("Errors occurred during s5cmd execution.")
+        return lines
+
+
+TABLE_FIELDS = {
+    "ID": "id",
+    "Hidden\nSamples": "hidden.samples",
+    "Hidden\nSize": "hidden.size",
+    "Sample\nSamples": "sample.samples",
+    "Sample\nSize": "sample.size",
+    "Name": "name",
+    "Title": "title",
+}
+
+
+def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+    datasets = extend_dataset_details(datasets)
+    datasets = sorted(datasets, key=lambda x: x["name"].lower())
+
+    table = rich.table.Table(title="Available Datasets")
+    for i_dataset, dataset in enumerate(datasets):
+        if i_dataset == 0:
+            for column_name, _ in TABLE_FIELDS.items():
+                table.add_column(column_name, justify="left", style="cyan", no_wrap=True)
+        row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
+        table.add_row(*row)
+
+    return table
+
+
+def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Extends dataset details with number of samples and size"""
+    for dataset in datasets:
+        for variant in dataset["dataset_variants"]:
+            variant_type = variant["variant_type"]
+            dataset[f"{variant_type}.samples"] = variant["number_of_data_items"]
+            dataset[f"{variant_type}.size"] = utils.size_human_readable(variant["size_bytes"])
+    return datasets
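
Not part of the diff: a sketch of how the new module might be used, assuming the import path hafnia.platform.datasets (from the file list above), a configured API key, and a placeholder dataset name.

```python
import rich

from hafnia.platform.datasets import (
    create_rich_table_from_dataset,
    dataset_list,
    download_or_get_dataset_path,
)

# List datasets on the platform and render them with rich.
datasets = dataset_list()
rich.print(create_rich_table_from_dataset(datasets))

# Download a dataset, or reuse the local copy if it already validates.
path_dataset = download_or_get_dataset_path("some-dataset-name", force_redownload=False)
print(path_dataset)
```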