hafnia 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. cli/__main__.py +16 -3
  2. cli/config.py +45 -4
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/keychain.py +88 -0
  8. cli/profile_cmds.py +10 -6
  9. cli/runc_cmds.py +5 -5
  10. cli/trainer_package_cmds.py +65 -0
  11. hafnia/__init__.py +2 -0
  12. hafnia/data/factory.py +1 -2
  13. hafnia/dataset/dataset_helpers.py +9 -14
  14. hafnia/dataset/dataset_names.py +10 -5
  15. hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
  16. hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
  17. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  18. hafnia/dataset/dataset_upload_helper.py +265 -56
  19. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  20. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  21. hafnia/dataset/hafnia_dataset.py +577 -213
  22. hafnia/dataset/license_types.py +63 -0
  23. hafnia/dataset/operations/dataset_stats.py +259 -3
  24. hafnia/dataset/operations/dataset_transformations.py +332 -7
  25. hafnia/dataset/operations/table_transformations.py +43 -5
  26. hafnia/dataset/primitives/__init__.py +8 -0
  27. hafnia/dataset/primitives/bbox.py +25 -12
  28. hafnia/dataset/primitives/bitmask.py +26 -14
  29. hafnia/dataset/primitives/classification.py +16 -8
  30. hafnia/dataset/primitives/point.py +7 -3
  31. hafnia/dataset/primitives/polygon.py +16 -9
  32. hafnia/dataset/primitives/segmentation.py +10 -7
  33. hafnia/experiment/hafnia_logger.py +111 -8
  34. hafnia/http.py +16 -2
  35. hafnia/platform/__init__.py +9 -3
  36. hafnia/platform/builder.py +12 -10
  37. hafnia/platform/dataset_recipe.py +104 -0
  38. hafnia/platform/datasets.py +47 -9
  39. hafnia/platform/download.py +25 -19
  40. hafnia/platform/experiment.py +51 -56
  41. hafnia/platform/trainer_package.py +57 -0
  42. hafnia/utils.py +81 -13
  43. hafnia/visualizations/image_visualizations.py +4 -4
  44. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
  45. hafnia-0.4.0.dist-info/RECORD +56 -0
  46. cli/recipe_cmds.py +0 -45
  47. hafnia-0.2.4.dist-info/RECORD +0 -49
  48. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  49. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  50. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/primitives/polygon.py CHANGED
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import cv2
 import numpy as np
+from pydantic import Field
 
 from hafnia.dataset.primitives.bitmask import Bitmask
 from hafnia.dataset.primitives.point import Point
@@ -11,15 +12,21 @@ from hafnia.dataset.primitives.utils import class_color_by_name, get_class_name
 
 class Polygon(Primitive):
     # Names should match names in FieldName
-    points: List[Point]
-    class_name: Optional[str] = None  # This should match the string in 'FieldName.CLASS_NAME'
-    class_idx: Optional[int] = None  # This should match the string in 'FieldName.CLASS_IDX'
-    object_id: Optional[str] = None  # This should match the string in 'FieldName.OBJECT_ID'
-    confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
-    ground_truth: bool = True  # Whether this is ground truth or a prediction
-
-    task_name: str = ""  # Task name to support multiple Polygon tasks in the same dataset. "" defaults to "polygon"
-    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
+    points: List[Point] = Field(description="List of points defining the polygon")
+    class_name: Optional[str] = Field(default=None, description="Class name of the polygon")
+    class_idx: Optional[int] = Field(default=None, description="Class index of the polygon")
+    object_id: Optional[str] = Field(default=None, description="Object ID of the polygon")
+    confidence: Optional[float] = Field(
+        default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
+    )
+    ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
+
+    task_name: str = Field(
+        default="", description="Task name to support multiple Polygon tasks in the same dataset. Defaults to 'polygon'"
+    )
+    meta: Optional[Dict[str, Any]] = Field(
+        default=None, description="This can be used to store additional information about the polygon"
+    )
 
     @staticmethod
     def from_list_of_points(
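Note: the change above moves the inline comments into pydantic `Field(description=...)` declarations without changing defaults. A minimal sketch of what that buys, assuming pydantic v2 and that `Point` accepts `x`/`y` keyword arguments (neither detail is shown in this diff):

```python
from hafnia.dataset.primitives.point import Point
from hafnia.dataset.primitives.polygon import Polygon

# Construct a Polygon as before; the optional fields keep their previous defaults.
polygon = Polygon(
    points=[Point(x=0.1, y=0.2), Point(x=0.4, y=0.2), Point(x=0.3, y=0.6)],  # assumed Point signature
    class_name="vehicle",
)

# The Field descriptions now surface in the generated JSON schema (pydantic v2 API).
schema = Polygon.model_json_schema()
print(schema["properties"]["points"]["description"])  # "List of points defining the polygon"
```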
hafnia/dataset/primitives/segmentation.py CHANGED
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import cv2
 import numpy as np
+from pydantic import Field
 
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.dataset.primitives.utils import get_class_name
@@ -9,15 +10,17 @@ from hafnia.visualizations.colors import get_n_colors
 
 
 class Segmentation(Primitive):
-    # mask: np.ndarray
-    class_names: Optional[List[str]] = None  # This should match the string in 'FieldName.CLASS_NAME'
-    ground_truth: bool = True  # Whether this is ground truth or a prediction
+    # WARNING: Segmentation masks have not been fully implemented yet
+    class_names: Optional[List[str]] = Field(default=None, description="Class names of the segmentation")
+    ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
-    # confidence: Optional[float] = None  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification
-    task_name: str = (
-        ""  # Task name to support multiple Segmentation tasks in the same dataset. "" defaults to "segmentation"
+    task_name: str = Field(
+        default="",
+        description="Task name to support multiple Segmentation tasks in the same dataset. Defaults to 'segmentation'",
+    )
+    meta: Optional[Dict[str, Any]] = Field(
+        default=None, description="This can be used to store additional information about the segmentation"
     )
-    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
 
     @staticmethod
     def default_task_name() -> str:
hafnia/experiment/hafnia_logger.py CHANGED
@@ -2,6 +2,7 @@ import json
 import os
 import platform
 import sys
+import textwrap
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -11,11 +12,19 @@ import pyarrow as pa
 import pyarrow.parquet as pq
 from pydantic import BaseModel, field_validator
 
-from hafnia.data.factory import load_dataset
-from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.log import sys_logger, user_logger
 from hafnia.utils import is_hafnia_cloud_job, now_as_str
 
+try:
+    import mlflow
+    import mlflow.tracking
+    import sagemaker_mlflow  # noqa: F401
+
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    user_logger.warning("MLFlow is not available")
+    MLFLOW_AVAILABLE = False
+
 
 class EntityType(Enum):
     """Types of entities that can be logged."""
@@ -87,17 +96,43 @@ class HafniaLogger:
         for path in create_paths:
             path.mkdir(parents=True, exist_ok=True)
 
+        path_file = self.path_model() / "HOW_TO_STORE_YOUR_MODEL.txt"
+        path_file.write_text(get_instructions_how_to_store_model())
+
         self.dataset_name: Optional[str] = None
         self.log_file = self._path_artifacts() / self.EXPERIMENT_FILE
         self.schema = Entity.create_schema()
+
+        # Initialize MLflow for remote jobs
+        self._mlflow_initialized = False
+        if is_hafnia_cloud_job() and MLFLOW_AVAILABLE:
+            self._init_mlflow()
+
         self.log_environment()
 
-    def load_dataset(self, dataset_name: str) -> HafniaDataset:
-        """
-        Load a dataset from the specified path.
-        """
-        self.dataset_name = dataset_name
-        return load_dataset(dataset_name)
+    def _init_mlflow(self):
+        """Initialize MLflow tracking for remote jobs."""
+        try:
+            # Set MLflow tracking URI from environment variable
+            tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
+            if tracking_uri:
+                mlflow.set_tracking_uri(tracking_uri)
+                user_logger.info(f"MLflow tracking URI set to: {tracking_uri}")
+
+            # Set experiment name from environment variable
+            experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
+            if experiment_name:
+                mlflow.set_experiment(experiment_name)
+                user_logger.info(f"MLflow experiment set to: {experiment_name}")
+
+            # Start MLflow run
+            run_name = os.getenv("MLFLOW_RUN_NAME", "undefined")
+            mlflow.start_run(run_name=run_name)
+            self._mlflow_initialized = True
+            user_logger.info("MLflow run started successfully")
+
+        except Exception as e:
+            user_logger.error(f"Failed to initialize MLflow: {e}")
 
     def path_local_experiment(self) -> Path:
         """Get the path for local experiment."""
@@ -153,6 +188,14 @@ class HafniaLogger:
         )
         self.write_entity(entity)
 
+        # Also log to MLflow if initialized
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.log_metric(name, value, step=step)
+        except Exception as e:
+            user_logger.error(f"Failed to log metric to MLflow: {e}")
+
     def log_configuration(self, configurations: Dict):
         self.log_hparams(configurations, "configuration.json")
 
@@ -166,6 +209,15 @@
             existing_params.update(params)
             file_path.write_text(json.dumps(existing_params, indent=2))
             user_logger.info(f"Saved parameters to {file_path}")
+
+            # Also log to MLflow if initialized
+            if not self._mlflow_initialized:
+                return
+            try:
+                mlflow.log_params(params)
+            except Exception as e:
+                user_logger.error(f"Failed to log params to MLflow: {e}")
+
         except Exception as e:
             user_logger.error(f"Failed to save parameters to {file_path}: {e}")
 
@@ -202,3 +254,54 @@
             pq.write_table(next_table, self.log_file)
         except Exception as e:
             sys_logger.error(f"Failed to flush logs: {e}")
+
+    def end_run(self) -> None:
+        """End the MLflow run if initialized."""
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.end_run()
+            self._mlflow_initialized = False
+            user_logger.info("MLflow run ended successfully")
+        except Exception as e:
+            user_logger.error(f"Failed to end MLflow run: {e}")
+
+    def __del__(self):
+        """Cleanup when logger is destroyed."""
+        self.end_run()
+
+
+def get_instructions_how_to_store_model() -> str:
+    instructions = textwrap.dedent(
+        """\
+        If you, against your expectations, don't see any models in this folder,
+        we have provided a small guide to help.
+
+        The hafnia TaaS framework expects models to be stored in a folder generated
+        by the hafnia logger. You will need to store models in this folder
+        to ensure that they are properly stored and accessible after training.
+
+        Please check your recipe script and ensure that the models are being stored
+        as expected by the TaaS framework.
+
+        Below is also a small example to demonstrate:
+
+        ```python
+        from hafnia.experiment import HafniaLogger
+
+        # Initiate Hafnia logger
+        logger = HafniaLogger()
+
+        # Folder path to store models - generated by the hafnia logger.
+        model_dir = logger.path_model()
+
+        # Example for storing a pytorch based model. Note: the model is stored in 'model_dir'
+        path_pytorch_model = model_dir / "model.pth"
+
+        # Finally save the model to the specified path
+        torch.save(model.state_dict(), path_pytorch_model)
+        ```
+        """
+    )
+
+    return instructions
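A hedged sketch of how the new MLflow mirroring is driven, based only on what this hunk shows: MLflow is initialized when `is_hafnia_cloud_job()` is true and the `mlflow`/`sagemaker_mlflow` imports succeed, and it reads `MLFLOW_TRACKING_URI`, `MLFLOW_EXPERIMENT_NAME` and `MLFLOW_RUN_NAME` from the environment. The no-argument `HafniaLogger()` construction and the `log_metric(name, value, step)` signature are assumptions, not shown in this diff:

```python
import os

from hafnia.experiment import HafniaLogger

# Assumed values; in a Hafnia cloud job these would be injected by the platform.
os.environ["MLFLOW_TRACKING_URI"] = "https://mlflow.example.com"
os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"
os.environ["MLFLOW_RUN_NAME"] = "baseline"

logger = HafniaLogger()  # assumed construction, mirroring the embedded example above

# Metrics and params are still written to the local parquet log; when _mlflow_initialized
# is set, they are mirrored via mlflow.log_metric / mlflow.log_params.
logger.log_metric("train/loss", 0.42, step=100)  # assumed method signature

# end_run() closes the MLflow run explicitly; __del__ calls it as a best-effort fallback.
logger.end_run()
```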
hafnia/http.py CHANGED
@@ -1,11 +1,11 @@
 import json
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import urllib3
 
 
-def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
+def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Union[Dict, List]:
     """Fetches data from the API endpoint.
 
     Args:
@@ -81,3 +81,17 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], mult
         return json.loads(response.data.decode("utf-8"))
     finally:
         http.clear()
+
+
+def delete(endpoint: str, headers: Dict) -> Dict:
+    """Sends a DELETE request to the specified endpoint."""
+    http = urllib3.PoolManager(retries=urllib3.Retry(3))
+    try:
+        response = http.request("DELETE", endpoint, headers=headers)
+
+        if response.status not in (200, 204):
+            error_details = response.data.decode("utf-8")
+            raise urllib3.exceptions.HTTPError(f"Request failed with status {response.status}: {error_details}")
+        return json.loads(response.data.decode("utf-8")) if response.data else {}
+    finally:
+        http.clear()
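A short usage sketch for the widened `fetch` return type and the new `delete` helper; the endpoint URLs and API key below are placeholders, not values from this diff:

```python
from hafnia import http

headers = {"Authorization": "hafnia_pat_..."}  # placeholder API key

# fetch() may now return either a dict (single object) or a list (collection endpoint).
recipes = http.fetch("https://api.example.com/v1/dataset-recipes", headers=headers)  # placeholder URL

# delete() retries up to 3 times, accepts 200/204, and returns the decoded JSON body
# (or {} when the response has no body, e.g. a 204).
http.delete("https://api.example.com/v1/dataset-recipes/some-id", headers=headers)  # placeholder URL
```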
hafnia/platform/__init__.py CHANGED
@@ -1,3 +1,4 @@
+from hafnia.platform.datasets import get_dataset_id
 from hafnia.platform.download import (
     download_resource,
     download_single_object,
@@ -5,17 +6,22 @@ from hafnia.platform.download import (
 )
 from hafnia.platform.experiment import (
     create_experiment,
-    create_recipe,
-    get_dataset_id,
+    get_environments,
     get_exp_environment_id,
+    pretty_print_training_environments,
 )
+from hafnia.platform.trainer_package import create_trainer_package, get_trainer_package_by_id, get_trainer_packages
 
 __all__ = [
     "get_dataset_id",
-    "create_recipe",
+    "create_trainer_package",
+    "get_trainer_packages",
+    "get_trainer_package_by_id",
     "get_exp_environment_id",
     "create_experiment",
     "download_resource",
     "download_single_object",
     "get_resource_credentials",
+    "pretty_print_training_environments",
+    "get_environments",
 ]
hafnia/platform/builder.py CHANGED
@@ -14,26 +14,28 @@ from hafnia.log import sys_logger, user_logger
 from hafnia.platform import download_resource
 
 
-def validate_recipe_format(path: Path) -> None:
-    """Validate Hafnia Recipe Format submition"""
+def validate_trainer_package_format(path: Path) -> None:
+    """Validate Hafnia Trainer Package Format submission"""
     hrf = zipfile.Path(path) if path.suffix == ".zip" else path
     required = {"src", "scripts", "Dockerfile"}
     errors = 0
     for rp in required:
         if not (hrf / rp).exists():
-            user_logger.error(f"Required path {rp} not found in recipe.")
+            user_logger.error(f"Required path {rp} not found in trainer package.")
             errors += 1
     if errors > 0:
-        raise FileNotFoundError("Wrong recipe structure")
+        raise FileNotFoundError("Wrong trainer package structure")
 
 
-def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None) -> Dict:
-    resource = download_resource(recipe_url, output_dir.as_posix(), api_key)
-    recipe_path = Path(resource["downloaded_files"][0])
-    with zipfile.ZipFile(recipe_path, "r") as zip_ref:
+def prepare_trainer_package(
+    trainer_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None
+) -> Dict:
+    resource = download_resource(trainer_url, output_dir.as_posix(), api_key)
+    trainer_path = Path(resource["downloaded_files"][0])
+    with zipfile.ZipFile(trainer_path, "r") as zip_ref:
         zip_ref.extractall(output_dir)
 
-    validate_recipe_format(output_dir)
+    validate_trainer_package_format(output_dir)
 
     scripts_dir = output_dir / "scripts"
     if not any(scripts_dir.iterdir()):
@@ -42,7 +44,7 @@ def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file:
     metadata = {
         "user_data": (output_dir / "src").as_posix(),
         "dockerfile": (output_dir / "Dockerfile").as_posix(),
-        "digest": sha256(recipe_path.read_bytes()).hexdigest()[:8],
+        "digest": sha256(trainer_path.read_bytes()).hexdigest()[:8],
     }
     state_file = state_file if state_file else output_dir / "state.json"
     with open(state_file, "w", encoding="utf-8") as f:
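A minimal sketch of the layout `validate_trainer_package_format` enforces, using a temporary directory as a stand-in for a real trainer package:

```python
import tempfile
from pathlib import Path

from hafnia.platform.builder import validate_trainer_package_format

with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp)
    (pkg / "src").mkdir()         # training code
    (pkg / "scripts").mkdir()     # entry-point scripts (checked for emptiness by prepare_trainer_package)
    (pkg / "Dockerfile").touch()  # build definition

    # Passes silently when 'src', 'scripts' and 'Dockerfile' all exist; otherwise logs each
    # missing path and raises FileNotFoundError("Wrong trainer package structure").
    validate_trainer_package_format(pkg)
```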
hafnia/platform/dataset_recipe.py ADDED
@@ -0,0 +1,104 @@
+import json
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from flatten_dict import flatten
+
+from hafnia import http
+from hafnia.log import user_logger
+from hafnia.utils import pretty_print_list_as_table, timed
+
+
+@timed("Get or create dataset recipe")
+def get_or_create_dataset_recipe(
+    recipe: dict,
+    endpoint: str,
+    api_key: str,
+    name: Optional[str] = None,
+    overwrite: bool = False,
+) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    data = {"template": {"body": recipe}, "overwrite": overwrite}
+    if name is not None:
+        data["name"] = name  # type: ignore[assignment]
+
+    response = http.post(endpoint, headers=headers, data=data)
+    return response
+
+
+def get_or_create_dataset_recipe_by_dataset_name(dataset_name: str, endpoint: str, api_key: str) -> Dict:
+    return get_or_create_dataset_recipe(recipe=dataset_name, endpoint=endpoint, api_key=api_key)
+
+
+def get_dataset_recipes(endpoint: str, api_key: str) -> List[Dict]:
+    headers = {"Authorization": api_key}
+    dataset_recipes: List[Dict] = http.fetch(endpoint, headers=headers)  # type: ignore[assignment]
+    return dataset_recipes
+
+
+def get_dataset_recipe_by_id(dataset_recipe_id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{dataset_recipe_id}"
+    dataset_recipe_info: Dict = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_recipe_info:
+        raise ValueError(f"Dataset recipe with ID '{dataset_recipe_id}' was not found.")
+    return dataset_recipe_info
+
+
+def get_or_create_dataset_recipe_from_path(
+    path_recipe_json: Path, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Dict:
+    path_recipe_json = Path(path_recipe_json)
+    if not path_recipe_json.exists():
+        raise FileNotFoundError(f"Dataset recipe file '{path_recipe_json}' does not exist.")
+    json_dict = json.loads(path_recipe_json.read_text())
+    return get_or_create_dataset_recipe(json_dict, endpoint=endpoint, api_key=api_key, name=name)
+
+
+def delete_dataset_recipe_by_id(id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{id}"
+    response = http.delete(endpoint=full_url, headers=headers)
+    return response
+
+
+@timed("Get dataset recipe")
+def get_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={name}"
+    dataset_recipes: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if len(dataset_recipes) == 0:
+        return None
+
+    if len(dataset_recipes) > 1:
+        user_logger.warning(f"Found {len(dataset_recipes)} dataset recipes called '{name}'. Using the first one.")
+
+    dataset_recipe = dataset_recipes[0]
+    return dataset_recipe
+
+
+def delete_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    recipe_response = get_dataset_recipe_by_name(name, endpoint=endpoint, api_key=api_key)
+
+    if recipe_response:
+        return delete_dataset_recipe_by_id(recipe_response["id"], endpoint=endpoint, api_key=api_key)
+    return recipe_response
+
+
+def pretty_print_dataset_recipes(recipes: List[Dict]) -> None:
+    recipes = [flatten(recipe, reducer="dot", max_flatten_depth=2) for recipe in recipes]  # noqa: F821
+    for recipe in recipes:
+        recipe["recipe_json"] = json.dumps(recipe["template.body"])[:20]
+
+    RECIPE_FIELDS = {
+        "ID": "id",
+        "Name": "name",
+        "Recipe": "recipe_json",
+        "Created": "created_at",
+        "IsDataset": "template.is_direct_dataset_reference",
+    }
+    pretty_print_list_as_table(
+        table_title="Available Dataset Recipes",
+        dict_items=recipes,
+        column_name_to_key_mapping=RECIPE_FIELDS,
+    )
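A hedged usage sketch for the new dataset-recipe helpers; the endpoint URL, API key and recipe body are placeholders rather than values taken from this diff:

```python
from hafnia.platform.dataset_recipe import (
    get_dataset_recipe_by_name,
    get_dataset_recipes,
    get_or_create_dataset_recipe,
    pretty_print_dataset_recipes,
)

endpoint = "https://api.example.com/v1/dataset-recipes"  # placeholder endpoint
api_key = "hafnia_pat_..."                               # placeholder key

# Create (or reuse) a recipe from a plain dict body; the real recipe schema lives in
# hafnia/dataset/dataset_recipe/dataset_recipe.py and is not shown here.
created = get_or_create_dataset_recipe(
    recipe={"dataset_name": "mnist"},  # assumed body
    endpoint=endpoint,
    api_key=api_key,
    name="mnist-baseline",
)

# Lookup is case-insensitive (?name__iexact=...); duplicates log a warning and use the first hit.
recipe = get_dataset_recipe_by_name("mnist-baseline", endpoint=endpoint, api_key=api_key)

# Render all recipes as a table via pretty_print_list_as_table.
pretty_print_dataset_recipes(get_dataset_recipes(endpoint=endpoint, api_key=api_key))
```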
hafnia/platform/datasets.py CHANGED
@@ -8,10 +8,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import rich
-from tqdm import tqdm
+from rich import print as rprint
+from rich.progress import track
 
 from cli.config import Config
-from hafnia import utils
+from hafnia import http, utils
 from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
 from hafnia.dataset.dataset_recipe.dataset_recipe import (
     DatasetRecipe,
@@ -20,13 +21,12 @@ from hafnia.dataset.dataset_recipe.dataset_recipe import (
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
-from hafnia.platform import get_dataset_id
 from hafnia.platform.download import get_resource_credentials
 from hafnia.utils import timed
 
 
 @timed("Fetching dataset list.")
-def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+def get_datasets(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     """List available datasets on the Hafnia platform."""
     cfg = cfg or Config()
     endpoint_dataset = cfg.get_platform_endpoint("datasets")
@@ -38,6 +38,19 @@ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     return datasets
 
 
+@timed("Fetching dataset info.")
+def get_dataset_id(dataset_name: str, endpoint: str, api_key: str) -> str:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={dataset_name}"
+    dataset_responses: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_responses:
+        raise ValueError(f"Dataset '{dataset_name}' was not found in the dataset library.")
+    try:
+        return dataset_responses[0]["id"]
+    except (IndexError, KeyError) as e:
+        raise ValueError("Dataset information is missing or invalid") from e
+
+
 def download_or_get_dataset_path(
     dataset_name: str,
     cfg: Optional[Config] = None,
@@ -109,7 +122,7 @@ def download_dataset_from_access_endpoint(
     try:
         fast_copy_files_s3(
             src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
-            dst_paths=dataset.samples[ColumnName.FILE_NAME].to_list(),
+            dst_paths=dataset.samples[ColumnName.FILE_PATH].to_list(),
            append_envs=envs,
            description="Downloading images",
        )
@@ -131,6 +144,28 @@ def fast_copy_files_s3(
     return lines
 
 
+def find_s5cmd() -> Optional[str]:
+    """Locate the s5cmd executable across different installation methods.
+
+    Searches for s5cmd in:
+    1. System PATH (via shutil.which)
+    2. Python bin directory (Unix-like systems)
+    3. Python executable directory (direct installs)
+
+    Returns:
+        str: Absolute path to s5cmd executable if found, None otherwise.
+    """
+    result = shutil.which("s5cmd")
+    if result:
+        return result
+    python_dir = Path(sys.executable).parent
+    locations = (python_dir / "Scripts" / "s5cmd.exe", python_dir / "bin" / "s5cmd", python_dir / "s5cmd")
+    for loc in locations:
+        if loc.exists():
+            return str(loc)
+    return None
+
+
 def execute_s5cmd_commands(
     commands: List[str],
     append_envs: Optional[Dict[str, str]] = None,
@@ -142,7 +177,10 @@ def execute_s5cmd_commands(
     with tempfile.TemporaryDirectory() as temp_dir:
         tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
         tmp_file_path.write_text("\n".join(commands))
-        s5cmd_bin = (Path(sys.executable).parent / "s5cmd").absolute().as_posix()
+
+        s5cmd_bin = find_s5cmd()
+        if s5cmd_bin is None:
+            raise ValueError("Can not find s5cmd executable.")
         run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
         sys_logger.debug(run_cmds)
         envs = os.environ.copy()
@@ -158,7 +196,7 @@ def execute_s5cmd_commands(
 
        error_lines = []
        lines = []
-        for line in tqdm(process.stdout, total=len(commands), desc=description):
+        for line in track(process.stdout, total=len(commands), description=description):
            if "ERROR" in line or "error" in line:
                error_lines.append(line.strip())
            lines.append(line.strip())
@@ -185,7 +223,7 @@ TABLE_FIELDS = {
 }
 
 
-def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+def pretty_print_datasets(datasets: List[Dict[str, str]]) -> None:
     datasets = extend_dataset_details(datasets)
     datasets = sorted(datasets, key=lambda x: x["name"].lower())
 
@@ -197,7 +235,7 @@ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table
        row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
        table.add_row(*row)
 
-    return table
+    rprint(table)
 
 
 def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
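A short sketch of the new lookup and s5cmd discovery helpers in hafnia.platform.datasets; the endpoint and API key are placeholders:

```python
from hafnia.platform.datasets import find_s5cmd, get_dataset_id

# find_s5cmd() checks PATH first, then the Python bin/Scripts directories;
# execute_s5cmd_commands() now raises ValueError when it returns None.
s5cmd_bin = find_s5cmd()
print(s5cmd_bin or "s5cmd not found")

# get_dataset_id() resolves a dataset name (case-insensitively) to its platform ID.
dataset_id = get_dataset_id(
    dataset_name="mnist",                            # assumed dataset name
    endpoint="https://api.example.com/v1/datasets",  # placeholder endpoint
    api_key="hafnia_pat_...",                        # placeholder key
)
```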
hafnia/platform/download.py CHANGED
@@ -1,10 +1,10 @@
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Optional
 
 import boto3
 from botocore.exceptions import ClientError
 from pydantic import BaseModel, field_validator
-from tqdm import tqdm
+from rich.progress import Progress
 
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
@@ -96,7 +96,8 @@ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials
         RuntimeError: If the call to fetch the credentials fails for any reason.
     """
     try:
-        credentials_dict = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+        headers = {"Authorization": api_key, "accept": "application/json"}
+        credentials_dict: Dict = fetch(endpoint, headers=headers)  # type: ignore[assignment]
         credentials = ResourceCredentials.fix_naming(credentials_dict)
         sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
         return credentials
@@ -124,13 +125,15 @@ def download_single_object(s3_client, bucket: str, object_key: str, output_dir:
     return local_path
 
 
-def download_resource(resource_url: str, destination: str, api_key: str) -> Dict:
+def download_resource(resource_url: str, destination: str, api_key: str, prefix: Optional[str] = None) -> Dict:
     """
     Downloads either a single file from S3 or all objects under a prefix.
 
     Args:
         resource_url (str): The URL or identifier used to fetch S3 credentials.
         destination (str): Path to local directory where files will be stored.
+        api_key (str): API key for authentication when fetching credentials.
+        prefix (Optional[str]): If provided, only download objects under this prefix.
 
     Returns:
         Dict[str, Any]: A dictionary containing download info, e.g.:
@@ -146,7 +149,7 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
     res_credentials = get_resource_credentials(resource_url, api_key)
 
     bucket_name = res_credentials.bucket_name()
-    key = res_credentials.object_key()
+    prefix = prefix or res_credentials.object_key()
 
     output_path = Path(destination)
     output_path.mkdir(parents=True, exist_ok=True)
@@ -158,29 +161,32 @@
     )
     downloaded_files = []
     try:
-        s3_client.head_object(Bucket=bucket_name, Key=key)
-        local_file = download_single_object(s3_client, bucket_name, key, output_path)
+        s3_client.head_object(Bucket=bucket_name, Key=prefix)
+        local_file = download_single_object(s3_client, bucket_name, prefix, output_path)
         downloaded_files.append(str(local_file))
         user_logger.info(f"Downloaded single file: {local_file}")
 
     except ClientError as e:
         error_code = e.response.get("Error", {}).get("Code")
         if error_code == "404":
-            sys_logger.debug(f"Object '{key}' not found; trying as a prefix.")
-            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key)
+            sys_logger.debug(f"Object '{prefix}' not found; trying as a prefix.")
+            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
             contents = response.get("Contents", [])
 
             if not contents:
-                raise ValueError(f"No objects found for prefix '{key}' in bucket '{bucket_name}'")
-            pbar = tqdm(contents)
-            for obj in pbar:
-                sub_key = obj["Key"]
-                size_mb = obj.get("Size", 0) / 1024 / 1024
-                pbar.set_description(f"{sub_key} ({size_mb:.2f} MB)")
-                local_file = download_single_object(s3_client, bucket_name, sub_key, output_path)
-                downloaded_files.append(local_file.as_posix())
-
-            user_logger.info(f"Downloaded folder/prefix '{key}' with {len(downloaded_files)} object(s).")
+                raise ValueError(f"No objects found for prefix '{prefix}' in bucket '{bucket_name}'")
+
+            with Progress() as progress:
+                task = progress.add_task("Downloading files", total=len(contents))
+                for obj in contents:
+                    sub_key = obj["Key"]
+                    size_mb = obj.get("Size", 0) / 1024 / 1024
+                    progress.update(task, description=f"Downloading {sub_key} ({size_mb:.2f} MB)")
+                    local_file = download_single_object(s3_client, bucket_name, sub_key, output_path)
+                    downloaded_files.append(local_file.as_posix())
+                    progress.advance(task)
+
+            user_logger.info(f"Downloaded folder/prefix '{prefix}' with {len(downloaded_files)} object(s).")
         else:
             user_logger.error(f"Error checking object or prefix: {e}")
             raise RuntimeError(f"Failed to check or download S3 resource: {e}") from e