hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (42)
  1. cli/__main__.py +2 -2
  2. cli/config.py +17 -4
  3. cli/dataset_cmds.py +60 -0
  4. cli/runc_cmds.py +1 -1
  5. hafnia/data/__init__.py +2 -2
  6. hafnia/data/factory.py +12 -56
  7. hafnia/dataset/dataset_helpers.py +91 -0
  8. hafnia/dataset/dataset_names.py +72 -0
  9. hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
  10. hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
  11. hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
  12. hafnia/dataset/dataset_upload_helper.py +468 -0
  13. hafnia/dataset/hafnia_dataset.py +624 -0
  14. hafnia/dataset/operations/dataset_stats.py +15 -0
  15. hafnia/dataset/operations/dataset_transformations.py +82 -0
  16. hafnia/dataset/operations/table_transformations.py +183 -0
  17. hafnia/dataset/primitives/__init__.py +16 -0
  18. hafnia/dataset/primitives/bbox.py +137 -0
  19. hafnia/dataset/primitives/bitmask.py +182 -0
  20. hafnia/dataset/primitives/classification.py +56 -0
  21. hafnia/dataset/primitives/point.py +25 -0
  22. hafnia/dataset/primitives/polygon.py +100 -0
  23. hafnia/dataset/primitives/primitive.py +44 -0
  24. hafnia/dataset/primitives/segmentation.py +51 -0
  25. hafnia/dataset/primitives/utils.py +51 -0
  26. hafnia/experiment/hafnia_logger.py +7 -7
  27. hafnia/helper_testing.py +108 -0
  28. hafnia/http.py +5 -3
  29. hafnia/platform/__init__.py +2 -2
  30. hafnia/platform/datasets.py +197 -0
  31. hafnia/platform/download.py +85 -23
  32. hafnia/torch_helpers.py +180 -95
  33. hafnia/utils.py +21 -2
  34. hafnia/visualizations/colors.py +267 -0
  35. hafnia/visualizations/image_visualizations.py +202 -0
  36. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
  37. hafnia-0.2.1.dist-info/RECORD +50 -0
  38. cli/data_cmds.py +0 -53
  39. hafnia-0.1.27.dist-info/RECORD +0 -27
  40. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
  41. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0
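The most visible change in 0.2.1 is the new hafnia.dataset package (dataset recipes, annotation primitives, and the HafniaDataset container), with cli/data_cmds.py removed in favour of cli/dataset_cmds.py. For orientation, a minimal sketch of the entry point that the new code in this diff uses to load a dataset from disk; the path is a placeholder and not part of the diff:

from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset

# Placeholder path; HafniaDataset.from_path() is how helper_testing.py and
# platform/datasets.py in this diff read a dataset folder from disk.
dataset = HafniaDataset.from_path(Path("./data/tiny-dataset"))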
hafnia/dataset/primitives/polygon.py ADDED
@@ -0,0 +1,100 @@
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+ import cv2
+ import numpy as np
+
+ from hafnia.dataset.primitives.bitmask import Bitmask
+ from hafnia.dataset.primitives.point import Point
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.dataset.primitives.utils import class_color_by_name, get_class_name
+
+
+ class Polygon(Primitive):
+     # Names should match names in FieldName
+     points: List[Point]
+     class_name: Optional[str] = None  # This should match the string in 'FieldName.CLASS_NAME'
+     class_idx: Optional[int] = None  # This should match the string in 'FieldName.CLASS_IDX'
+     object_id: Optional[str] = None  # This should match the string in 'FieldName.OBJECT_ID'
+     confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
+     ground_truth: bool = True  # Whether this is ground truth or a prediction
+
+     task_name: str = ""  # Task name to support multiple Polygon tasks in the same dataset. "" defaults to "polygon"
+     meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
+
+     @staticmethod
+     def from_list_of_points(
+         points: Sequence[Sequence[float]],
+         class_name: Optional[str] = None,
+         class_idx: Optional[int] = None,
+         object_id: Optional[str] = None,
+     ) -> "Polygon":
+         list_points = [Point(x=point[0], y=point[1]) for point in points]
+         return Polygon(points=list_points, class_name=class_name, class_idx=class_idx, object_id=object_id)
+
+     @staticmethod
+     def default_task_name() -> str:
+         return "polygon"
+
+     @staticmethod
+     def column_name() -> str:
+         return "polygons"
+
+     def calculate_area(self) -> float:
+         raise NotImplementedError()
+
+     def to_pixel_coordinates(
+         self, image_shape: Tuple[int, int], as_int: bool = True, clip_values: bool = True
+     ) -> List[Tuple]:
+         points = [
+             point.to_pixel_coordinates(image_shape=image_shape, as_int=as_int, clip_values=clip_values)
+             for point in self.points
+         ]
+         return points
+
+     def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         points = np.array(self.to_pixel_coordinates(image_shape=image.shape[:2]))
+
+         bottom_left_idx = np.lexsort((-points[:, 1], points[:, 0]))[0]
+         bottom_left_np = points[bottom_left_idx, :]
+         margin = 5
+         bottom_left = (bottom_left_np[0] + margin, bottom_left_np[1] - margin)
+
+         class_name = self.get_class_name()
+         color = class_color_by_name(class_name)
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
+         cv2.putText(
+             img=image, text=class_name, org=bottom_left, fontFace=font, fontScale=0.75, color=color, thickness=2
+         )
+         return image
+
+     def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         points = np.array(self.to_pixel_coordinates(image_shape=image.shape[:2]))
+         mask = np.zeros(image.shape[:2], dtype=np.uint8)
+         mask = cv2.fillPoly(mask, [points], color=255).astype(bool)
+         bitmask = Bitmask.from_mask(mask=mask, top=0, left=0).squeeze_mask()
+         image = bitmask.anonymize_by_blurring(image=image, inplace=inplace, max_resolution=max_resolution)
+
+         return image
+
+     def mask(
+         self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
+     ) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+         points = self.to_pixel_coordinates(image_shape=image.shape[:2])
+
+         if color is None:
+             mask = np.zeros_like(image[:, :, 0])
+             bitmask = cv2.fillPoly(mask, pts=[np.array(points)], color=255).astype(bool)  # type: ignore[assignment]
+             color = tuple(int(value) for value in np.mean(image[bitmask], axis=0))  # type: ignore[assignment]
+
+         cv2.fillPoly(image, [np.array(points)], color=color)
+         return image
+
+     def get_class_name(self) -> str:
+         return get_class_name(self.class_name, self.class_idx)
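A minimal usage sketch of the new Polygon primitive, based only on the constructor and methods shown above; the coordinates and class name are placeholders, and it is assumed (from to_pixel_coordinates) that Point stores normalized x/y values:

import numpy as np
from hafnia.dataset.primitives.polygon import Polygon

# Hypothetical triangle in normalized coordinates.
polygon = Polygon.from_list_of_points(
    points=[(0.1, 0.1), (0.5, 0.8), (0.9, 0.2)],
    class_name="person",
)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image
image_with_polygon = polygon.draw(image)  # returns a copy unless inplace=True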
hafnia/dataset/primitives/primitive.py ADDED
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ from abc import ABCMeta, abstractmethod
+
+ import numpy as np
+ from pydantic import BaseModel
+
+
+ class Primitive(BaseModel, metaclass=ABCMeta):
+     def model_post_init(self, context) -> None:
+         if self.task_name == "":  # type: ignore[has-type]  # Hack because 'task_name' doesn't exist in base-class yet.
+             self.task_name = self.default_task_name()
+
+     @staticmethod
+     @abstractmethod
+     def default_task_name() -> str:
+         # E.g. "return bboxes" for Bbox
+         raise NotImplementedError
+
+     @staticmethod
+     @abstractmethod
+     def column_name() -> str:
+         """
+         Name of field used in hugging face datasets for storing annotations
+         E.g. "objects" for Bbox.
+         """
+         pass
+
+     @abstractmethod
+     def calculate_area(self) -> float:
+         # Calculate the area of the primitive
+         pass
+
+     @abstractmethod
+     def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+         pass
+
+     @abstractmethod
+     def mask(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+         pass
+
+     @abstractmethod
+     def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+         pass
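A hypothetical minimal subclass, illustrating what the abstract interface above requires and how model_post_init fills in task_name; "Dot" is an invented example, not part of the package:

from typing import Any
import numpy as np
from hafnia.dataset.primitives.primitive import Primitive

class Dot(Primitive):
    x: float
    y: float
    task_name: str = ""  # left empty so model_post_init() falls back to default_task_name()

    @staticmethod
    def default_task_name() -> str:
        return "dot"

    @staticmethod
    def column_name() -> str:
        return "dots"

    def calculate_area(self) -> float:
        return 0.0

    def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
        return image if inplace else image.copy()

    def mask(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
        return image if inplace else image.copy()

    def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
        return image if inplace else image.copy()

dot = Dot(x=0.5, y=0.5)
assert dot.task_name == "dot"  # set by Primitive.model_post_init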
hafnia/dataset/primitives/segmentation.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import cv2
+ import numpy as np
+
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.dataset.primitives.utils import get_class_name
+ from hafnia.visualizations.colors import get_n_colors
+
+
+ class Segmentation(Primitive):
+     # mask: np.ndarray
+     class_names: Optional[List[str]] = None  # This should match the string in 'FieldName.CLASS_NAME'
+     ground_truth: bool = True  # Whether this is ground truth or a prediction
+
+     # confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification
+     task_name: str = (
+         ""  # Task name to support multiple Segmentation tasks in the same dataset. "" defaults to "segmentation"
+     )
+     meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
+
+     @staticmethod
+     def default_task_name() -> str:
+         return "segmentation"
+
+     @staticmethod
+     def column_name() -> str:
+         return "segmentation"
+
+     def calculate_area(self) -> float:
+         raise NotImplementedError()
+
+     def draw(self, image: np.ndarray, inplace: bool = False) -> np.ndarray:
+         if not inplace:
+             image = image.copy()
+
+         color_mapping = np.asarray(get_n_colors(len(self.class_names)), dtype=np.uint8)  # type: ignore[arg-type]
+         label_image = color_mapping[self.mask]
+         blended = cv2.addWeighted(image, 0.5, label_image, 0.5, 0)
+         return blended
+
+     def mask(
+         self, image: np.ndarray, inplace: bool = False, color: Optional[Tuple[np.uint8, np.uint8, np.uint8]] = None
+     ) -> np.ndarray:
+         return image
+
+     def anonymize_by_blurring(self, image: np.ndarray, inplace: bool = False, max_resolution: int = 20) -> np.ndarray:
+         return image
+
+     def get_class_name(self) -> str:
+         return get_class_name(self.class_name, self.class_idx)
hafnia/dataset/primitives/utils.py ADDED
@@ -0,0 +1,51 @@
+ import hashlib
+ from typing import Optional, Tuple, Union
+
+ import cv2
+ import numpy as np
+
+
+ def text_org_from_left_bottom_to_centered(xy_org: tuple, text: str, font, font_scale: float, thickness: int) -> tuple:
+     xy_text_size = cv2.getTextSize(text, fontFace=font, fontScale=font_scale, thickness=thickness)[0]
+     xy_text_size_half = np.array(xy_text_size) / 2
+     xy_centered_np = xy_org + xy_text_size_half * np.array([-1, 1])
+     xy_centered = tuple(int(value) for value in xy_centered_np)
+     return xy_centered
+
+
+ def round_int_clip_value(value: Union[int, float], max_value: int) -> int:
+     return clip(value=int(round(value)), v_min=0, v_max=max_value)  # noqa: RUF046
+
+
+ def class_color_by_name(name: str) -> Tuple[int, int, int]:
+     # Create a hash of the class name
+     hash_object = hashlib.md5(name.encode())
+     # Use the hash to generate a color
+     hash_digest = hash_object.hexdigest()
+     color = (int(hash_digest[0:2], 16), int(hash_digest[2:4], 16), int(hash_digest[4:6], 16))
+     return color
+
+
+ # Define an abstract base class
+ def clip(value, v_min, v_max):
+     return min(max(v_min, value), v_max)
+
+
+ def get_class_name(class_name: Optional[str], class_idx: Optional[int]) -> str:
+     if class_name is not None:
+         return class_name
+     if class_idx is not None:
+         return f"IDX:{class_idx}"
+     return "NoName"
+
+
+ def anonymize_by_resizing(blur_region: np.ndarray, max_resolution: int = 20) -> np.ndarray:
+     """
+     Removes high-frequency details from a region of an image by resizing it down and then back up.
+     """
+     original_shape = blur_region.shape[:2]
+     resize_factor = max(original_shape) / max_resolution
+     new_size = (int(original_shape[0] / resize_factor), int(original_shape[1] / resize_factor))
+     blur_region_downsized = cv2.resize(blur_region, new_size[::-1], interpolation=cv2.INTER_LINEAR)
+     blur_region_upsized = cv2.resize(blur_region_downsized, original_shape[::-1], interpolation=cv2.INTER_LINEAR)
+     return blur_region_upsized
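A short usage sketch for the helpers above; the crop is a placeholder array, but the behaviour shown (deterministic colors, shape-preserving blur) follows directly from the code:

import numpy as np
from hafnia.dataset.primitives.utils import anonymize_by_resizing, class_color_by_name

# Colors are derived from an MD5 hash of the class name, so they are stable across runs.
assert class_color_by_name("person") == class_color_by_name("person")

# Blur a hypothetical 100x80 crop by downscaling its longest side to 20 px and back up.
crop = np.random.randint(0, 255, size=(100, 80, 3), dtype=np.uint8)
blurred = anonymize_by_resizing(crop, max_resolution=20)
assert blurred.shape == crop.shape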
hafnia/experiment/hafnia_logger.py CHANGED
@@ -9,12 +9,12 @@ from typing import Dict, Optional, Union

  import pyarrow as pa
  import pyarrow.parquet as pq
- from datasets import DatasetDict
  from pydantic import BaseModel, field_validator

  from hafnia.data.factory import load_dataset
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
  from hafnia.log import sys_logger, user_logger
- from hafnia.utils import is_remote_job, now_as_str
+ from hafnia.utils import is_hafnia_cloud_job, now_as_str


  class EntityType(Enum):
@@ -92,7 +92,7 @@ class HafniaLogger:
          self.schema = Entity.create_schema()
          self.log_environment()

-     def load_dataset(self, dataset_name: str) -> DatasetDict:
+     def load_dataset(self, dataset_name: str) -> HafniaDataset:
          """
          Load a dataset from the specified path.
          """
@@ -101,7 +101,7 @@

      def path_local_experiment(self) -> Path:
          """Get the path for local experiment."""
-         if is_remote_job():
+         if is_hafnia_cloud_job():
              raise RuntimeError("Cannot access local experiment path in remote job.")
          return self._local_experiment_path

@@ -110,7 +110,7 @@
          if "MDI_CHECKPOINT_DIR" in os.environ:
              return Path(os.environ["MDI_CHECKPOINT_DIR"])

-         if is_remote_job():
+         if is_hafnia_cloud_job():
              return Path("/opt/ml/checkpoints")
          return self.path_local_experiment() / "checkpoints"

@@ -119,7 +119,7 @@
          if "MDI_ARTIFACT_DIR" in os.environ:
              return Path(os.environ["MDI_ARTIFACT_DIR"])

-         if is_remote_job():
+         if is_hafnia_cloud_job():
              return Path("/opt/ml/output/data")

          return self.path_local_experiment() / "data"
@@ -129,7 +129,7 @@
          if "MDI_MODEL_DIR" in os.environ:
              return Path(os.environ["MDI_MODEL_DIR"])

-         if is_remote_job():
+         if is_hafnia_cloud_job():
              return Path("/opt/ml/model")

          return self.path_local_experiment() / "model"
hafnia/helper_testing.py ADDED
@@ -0,0 +1,108 @@
+ from inspect import getmembers, isfunction, signature
+ from pathlib import Path
+ from types import FunctionType
+ from typing import Any, Callable, Dict, Union, get_origin
+
+ from hafnia import utils
+ from hafnia.dataset.dataset_names import FILENAME_ANNOTATIONS_JSONL, DatasetVariant
+ from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+ MICRO_DATASETS = {
+     "tiny-dataset": utils.PATH_DATASETS / "tiny-dataset",
+     "coco-2017": utils.PATH_DATASETS / "coco-2017",
+ }
+
+
+ def get_path_workspace() -> Path:
+     return Path(__file__).parents[2]
+
+
+ def get_path_expected_images() -> Path:
+     return get_path_workspace() / "tests" / "data" / "expected_images"
+
+
+ def get_path_test_data() -> Path:
+     return get_path_workspace() / "tests" / "data"
+
+
+ def get_path_micro_hafnia_dataset_no_check() -> Path:
+     return get_path_test_data() / "micro_test_datasets"
+
+
+ def get_path_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Path:
+     import pytest
+
+     if dataset_name not in MICRO_DATASETS:
+         raise ValueError(f"Dataset name '{dataset_name}' is not recognized. Available options: {list(MICRO_DATASETS)}")
+     path_dataset = MICRO_DATASETS[dataset_name]
+
+     path_test_dataset = get_path_micro_hafnia_dataset_no_check() / dataset_name
+     path_test_dataset_annotations = path_test_dataset / FILENAME_ANNOTATIONS_JSONL
+     if path_test_dataset_annotations.exists() and not force_update:
+         return path_test_dataset
+
+     hafnia_dataset = HafniaDataset.from_path(path_dataset / DatasetVariant.SAMPLE.value)
+     hafnia_dataset = hafnia_dataset.select_samples(n_samples=3, seed=42)
+     hafnia_dataset.write(path_test_dataset)
+
+     if force_update:
+         pytest.fail(
+             "Sample image and metadata have been updated using 'force_update=True'. Set 'force_update=False' and rerun the test."
+         )
+     pytest.fail("Missing test sample image. Please rerun the test.")
+     return path_test_dataset
+
+
+ def get_sample_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Sample:
+     micro_dataset = get_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     sample_dict = micro_dataset[0]
+     sample = Sample(**sample_dict)
+     return sample
+
+
+ def get_micro_hafnia_dataset(dataset_name: str, force_update: bool = False) -> HafniaDataset:
+     path_dataset = get_path_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     hafnia_dataset = HafniaDataset.from_path(path_dataset)
+     return hafnia_dataset
+
+
+ def is_hafnia_configured() -> bool:
+     """
+     Check if Hafnia is configured by verifying if the API key is set.
+     """
+     from cli.config import Config
+
+     return Config().is_configured()
+
+
+ def is_typing_type(annotation: Any) -> bool:
+     return get_origin(annotation) is not None
+
+
+ def annotation_as_string(annotation: Union[type, str]) -> str:
+     """Convert type annotation to string."""
+     if isinstance(annotation, str):
+         return annotation.replace("'", "")
+     if is_typing_type(annotation):  # Is using typing types like List, Dict, etc.
+         return str(annotation).replace("typing.", "")
+     if hasattr(annotation, "__name__"):
+         return annotation.__name__
+     return str(annotation)
+
+
+ def get_hafnia_functions_from_module(python_module) -> Dict[str, FunctionType]:
+     def dataset_is_first_arg(func: Callable) -> bool:
+         """
+         Check if the function has 'HafniaDataset' as the first parameter.
+         """
+         func_signature = signature(func)
+         params = func_signature.parameters
+         if len(params) == 0:
+             return False
+         first_argument_type = list(params.values())[0]
+
+         annotation_as_str = annotation_as_string(first_argument_type.annotation)
+         return annotation_as_str == "HafniaDataset"
+
+     functions = {func[0]: func[1] for func in getmembers(python_module, isfunction) if dataset_is_first_arg(func[1])}
+     return functions
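A small sketch of the module-introspection helper above, assuming hafnia.dataset.operations.dataset_transformations (added elsewhere in this diff) exposes functions whose first parameter is annotated as HafniaDataset; which names are returned depends entirely on that module's annotations:

from hafnia.dataset.operations import dataset_transformations
from hafnia.helper_testing import get_hafnia_functions_from_module

# Collect every function in the module whose first parameter is annotated 'HafniaDataset'.
functions = get_hafnia_functions_from_module(dataset_transformations)
print(sorted(functions))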
hafnia/http.py CHANGED
@@ -31,7 +31,7 @@ def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
          http.clear()


- def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart: bool = False) -> Dict:
+ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], multipart: bool = False) -> Dict:
      """Posts data to backend endpoint.

      Args:
@@ -64,9 +64,11 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart
          with open(data, "rb") as f:
              body = f.read()
          response = http.request("POST", endpoint, body=body, headers=headers)
-     elif isinstance(data, dict):
+     elif isinstance(data, (str, dict)):
+         if isinstance(data, dict):
+             data = json.dumps(data)
          headers["Content-Type"] = "application/json"
-         response = http.request("POST", endpoint, body=json.dumps(data), headers=headers)
+         response = http.request("POST", endpoint, body=data, headers=headers)
      elif isinstance(data, bytes):
          response = http.request("POST", endpoint, body=data, headers=headers)
      else:
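With this change, post() accepts either a dict (serialized internally) or an already-serialized JSON string. A hedged sketch; the endpoint URL and header values are placeholders, not taken from the diff:

import json
from hafnia.http import post

headers = {"Authorization": "<api-key>"}  # placeholder credentials
payload = {"name": "my-dataset"}

# Equivalent after this change: dicts are json.dumps'ed inside post(),
# strings are assumed to already be JSON.
response_a = post("https://api.example.com/datasets", headers=headers, data=payload)
response_b = post("https://api.example.com/datasets", headers=headers, data=json.dumps(payload))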
hafnia/platform/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from hafnia.platform.download import (
      download_resource,
      download_single_object,
-     get_resource_creds,
+     get_resource_credentials,
  )
  from hafnia.platform.experiment import (
      create_experiment,
@@ -17,5 +17,5 @@ __all__ = [
      "create_experiment",
      "download_resource",
      "download_single_object",
-     "get_resource_creds",
+     "get_resource_credentials",
  ]
hafnia/platform/datasets.py ADDED
@@ -0,0 +1,197 @@
+ import os
+ import shutil
+ import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import rich
+ from tqdm import tqdm
+
+ from cli.config import Config
+ from hafnia import utils
+ from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
+ from hafnia.dataset.dataset_recipe.dataset_recipe import (
+     DatasetRecipe,
+     get_dataset_path_from_recipe,
+ )
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
+ from hafnia.http import fetch
+ from hafnia.log import sys_logger, user_logger
+ from hafnia.platform import get_dataset_id
+ from hafnia.platform.download import get_resource_credentials
+ from hafnia.utils import timed
+
+
+ @timed("Fetching dataset list.")
+ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+     """List available datasets on the Hafnia platform."""
+     cfg = cfg or Config()
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     header = {"Authorization": cfg.api_key}
+     datasets: List[Dict[str, str]] = fetch(endpoint_dataset, headers=header)  # type: ignore
+     if not datasets:
+         raise ValueError("No datasets found on the Hafnia platform.")
+
+     return datasets
+
+
+ def download_or_get_dataset_path(
+     dataset_name: str,
+     cfg: Optional[Config] = None,
+     path_datasets_folder: Optional[str] = None,
+     force_redownload: bool = False,
+     download_files: bool = True,
+ ) -> Path:
+     """Download or get the path of the dataset."""
+     recipe_explicit = DatasetRecipe.from_implicit_form(dataset_name)
+     path_dataset = get_dataset_path_from_recipe(recipe_explicit, path_datasets=path_datasets_folder)
+
+     is_dataset_valid = HafniaDataset.check_dataset_path(path_dataset, raise_error=False)
+     if is_dataset_valid and not force_redownload:
+         user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
+         return path_dataset
+
+     cfg = cfg or Config()
+     api_key = cfg.api_key
+
+     shutil.rmtree(path_dataset, ignore_errors=True)
+
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     dataset_id = get_dataset_id(dataset_name=dataset_name, endpoint=endpoint_dataset, api_key=api_key)
+     if dataset_id is None:
+         sys_logger.error(f"Dataset '{dataset_name}' not found on the Hafnia platform.")
+     access_dataset_endpoint = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
+
+     download_dataset_from_access_endpoint(
+         endpoint=access_dataset_endpoint,
+         api_key=api_key,
+         path_dataset=path_dataset,
+         download_files=download_files,
+     )
+     return path_dataset
+
+
+ def download_dataset_from_access_endpoint(
+     endpoint: str,
+     api_key: str,
+     path_dataset: Path,
+     download_files: bool = True,
+ ) -> None:
+     resource_credentials = get_resource_credentials(endpoint, api_key)
+
+     local_dataset_paths = [str(path_dataset / filename) for filename in DATASET_FILENAMES_REQUIRED]
+     s3_uri = resource_credentials.s3_uri()
+     s3_dataset_files = [f"{s3_uri}/{filename}" for filename in DATASET_FILENAMES_REQUIRED]
+
+     envs = resource_credentials.aws_credentials()
+     fast_copy_files_s3(
+         src_paths=s3_dataset_files,
+         dst_paths=local_dataset_paths,
+         append_envs=envs,
+         description="Downloading annotations",
+     )
+
+     if not download_files:
+         return
+
+     dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
+     fast_copy_files_s3(
+         src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
+         dst_paths=dataset.samples[ColumnName.FILE_NAME].to_list(),
+         append_envs=envs,
+         description="Downloading images",
+     )
+
+
+ def fast_copy_files_s3(
+     src_paths: List[str],
+     dst_paths: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Copying files",
+ ) -> List[str]:
+     if len(src_paths) != len(dst_paths):
+         raise ValueError("Source and destination paths must have the same length.")
+
+     cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
+     lines = execute_s5cmd_commands(cmds, append_envs=append_envs, description=description)
+     return lines
+
+
+ def execute_s5cmd_commands(
+     commands: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Executing s5cmd commands",
+ ) -> List[str]:
+     append_envs = append_envs or {}
+     with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_file:
+         tmp_file_path = Path(tmp_file.name)
+         tmp_file_path.write_text("\n".join(commands))
+         run_cmds = [
+             "s5cmd",
+             "run",
+             str(tmp_file_path),
+         ]
+         envs = os.environ.copy()
+         envs.update(append_envs)
+
+         process = subprocess.Popen(
+             run_cmds,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             universal_newlines=True,
+             env=envs,
+         )
+
+         error_lines = []
+         lines = []
+         for line in tqdm(process.stdout, total=len(commands), desc=description):
+             if "ERROR" in line or "error" in line:
+                 error_lines.append(line.strip())
+             lines.append(line.strip())
+
+         if len(error_lines) > 0:
+             show_n_lines = min(5, len(error_lines))
+             str_error_lines = "\n".join(error_lines[:show_n_lines])
+             user_logger.error(
+                 f"Detected {len(error_lines)} errors occurred while executing a total of {len(commands)} "
+                 f" commands with s5cmd. The first {show_n_lines} is printed below:\n{str_error_lines}"
+             )
+             raise RuntimeError("Errors occurred during s5cmd execution.")
+         return lines
+
+
+ TABLE_FIELDS = {
+     "ID": "id",
+     "Hidden\nSamples": "hidden.samples",
+     "Hidden\nSize": "hidden.size",
+     "Sample\nSamples": "sample.samples",
+     "Sample\nSize": "sample.size",
+     "Name": "name",
+     "Title": "title",
+ }
+
+
+ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+     datasets = extend_dataset_details(datasets)
+     datasets = sorted(datasets, key=lambda x: x["name"].lower())
+
+     table = rich.table.Table(title="Available Datasets")
+     for i_dataset, dataset in enumerate(datasets):
+         if i_dataset == 0:
+             for column_name, _ in TABLE_FIELDS.items():
+                 table.add_column(column_name, justify="left", style="cyan", no_wrap=True)
+         row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
+         table.add_row(*row)
+
+     return table
+
+
+ def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Extends dataset details with number of samples and size"""
+     for dataset in datasets:
+         for variant in dataset["dataset_variants"]:
+             variant_type = variant["variant_type"]
+             dataset[f"{variant_type}.samples"] = variant["number_of_data_items"]
+             dataset[f"{variant_type}.size"] = utils.size_human_readable(variant["size_bytes"])
+     return datasets
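A short usage sketch of the new platform helpers above; it assumes a configured Hafnia API key (cli.config.Config) and an installed s5cmd binary, and the dataset name is a placeholder:

import rich

from hafnia.platform.datasets import (
    create_rich_table_from_dataset,
    dataset_list,
    download_or_get_dataset_path,
)

# List datasets available to the configured account and render the summary table.
datasets = dataset_list()
rich.print(create_rich_table_from_dataset(datasets))

# Download a dataset (placeholder name) or reuse the local copy if it already exists.
path = download_or_get_dataset_path("tiny-dataset", force_redownload=False)
print(path)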