hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +13 -2
  2. cli/config.py +2 -1
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/profile_cmds.py +6 -5
  8. cli/runc_cmds.py +5 -5
  9. cli/trainer_package_cmds.py +65 -0
  10. hafnia/__init__.py +2 -0
  11. hafnia/data/factory.py +1 -2
  12. hafnia/dataset/dataset_helpers.py +0 -12
  13. hafnia/dataset/dataset_names.py +8 -4
  14. hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
  15. hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
  16. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  17. hafnia/dataset/dataset_upload_helper.py +206 -53
  18. hafnia/dataset/hafnia_dataset.py +432 -194
  19. hafnia/dataset/license_types.py +63 -0
  20. hafnia/dataset/operations/dataset_stats.py +260 -3
  21. hafnia/dataset/operations/dataset_transformations.py +325 -4
  22. hafnia/dataset/operations/table_transformations.py +39 -2
  23. hafnia/dataset/primitives/__init__.py +8 -0
  24. hafnia/dataset/primitives/classification.py +1 -1
  25. hafnia/experiment/hafnia_logger.py +112 -0
  26. hafnia/http.py +16 -2
  27. hafnia/platform/__init__.py +9 -3
  28. hafnia/platform/builder.py +12 -10
  29. hafnia/platform/dataset_recipe.py +99 -0
  30. hafnia/platform/datasets.py +44 -6
  31. hafnia/platform/download.py +2 -1
  32. hafnia/platform/experiment.py +51 -56
  33. hafnia/platform/trainer_package.py +57 -0
  34. hafnia/utils.py +64 -13
  35. hafnia/visualizations/image_visualizations.py +3 -3
  36. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
  37. hafnia-0.3.0.dist-info/RECORD +53 -0
  38. cli/recipe_cmds.py +0 -45
  39. hafnia-0.2.4.dist-info/RECORD +0 -49
  40. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
  41. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ from tqdm import tqdm
 from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_JSONL,
     FILENAME_ANNOTATIONS_PARQUET,
+    ColumnName,
     FieldName,
 )
 from hafnia.dataset.operations import table_transformations
@@ -34,8 +35,12 @@ def create_primitive_table(
 
     if keep_sample_data:
         # Drop other primitive columns to avoid conflicts
-        drop_columns = set(PRIMITIVE_TYPES) - {PrimitiveType, Classification}
-        remove_no_object_frames = remove_no_object_frames.drop(*[primitive.column_name() for primitive in drop_columns])
+
+        drop_columns_primitives = set(PRIMITIVE_TYPES) - {PrimitiveType, Classification}
+        drop_columns_names = [primitive.column_name() for primitive in drop_columns_primitives]
+        drop_columns_names = [c for c in drop_columns_names if c in remove_no_object_frames.columns]
+
+        remove_no_object_frames = remove_no_object_frames.drop(drop_columns_names)
         # Rename columns "height", "width" and "meta" for sample to avoid conflicts with object fields names
         remove_no_object_frames = remove_no_object_frames.rename(
             {"height": "image.height", "width": "image.width", "meta": "image.meta"}
@@ -46,6 +51,38 @@ def create_primitive_table(
     return objects_df
 
 
+def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFrame:
+    has_same_schema = samples0.schema == samples1.schema
+    if not has_same_schema:
+        shared_columns = []
+        for column_name, column_type in samples0.schema.items():
+            if column_name not in samples1.schema:
+                continue
+
+            if column_type != samples1.schema[column_name]:
+                continue
+            shared_columns.append(column_name)
+
+        dropped_columns0 = [
+            f"{n}[{ctype._string_repr()}]" for n, ctype in samples0.schema.items() if n not in shared_columns
+        ]
+        dropped_columns1 = [
+            f"{n}[{ctype._string_repr()}]" for n, ctype in samples1.schema.items() if n not in shared_columns
+        ]
+        user_logger.warning(
+            "Datasets with different schemas are being merged. "
+            "Only the columns with the same name and type will be kept in the merged dataset.\n"
+            f"Dropped columns in samples0: {dropped_columns0}\n"
+            f"Dropped columns in samples1: {dropped_columns1}\n"
+        )
+
+        samples0 = samples0.select(list(shared_columns))
+        samples1 = samples1.select(list(shared_columns))
+    merged_samples = pl.concat([samples0, samples1], how="vertical")
+    merged_samples = merged_samples.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+    return merged_samples
+
+
 def filter_table_for_class_names(
     samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
 ) -> Optional[pl.DataFrame]:
hafnia/dataset/primitives/__init__.py CHANGED
@@ -14,3 +14,11 @@ from .utils import class_color_by_name  # noqa: F401
 PRIMITIVE_TYPES: List[Type[Primitive]] = [Bbox, Classification, Polygon, Bitmask]
 PRIMITIVE_NAME_TO_TYPE = {cls.__name__: cls for cls in PRIMITIVE_TYPES}
 PRIMITIVE_COLUMN_NAMES: List[str] = [PrimitiveType.column_name() for PrimitiveType in PRIMITIVE_TYPES]
+
+
+def get_primitive_type_from_string(name: str) -> Type[Primitive]:
+    if name not in PRIMITIVE_NAME_TO_TYPE:
+        raise ValueError(
+            f"Primitive '{name}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
+        )
+    return PRIMITIVE_NAME_TO_TYPE[name]
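
The new helper above resolves a primitive name to its class. A minimal usage sketch, assuming the module path hafnia/dataset/primitives shown in the file list; the valid names are the class names in PRIMITIVE_TYPES:

```python
from hafnia.dataset.primitives import get_primitive_type_from_string

# Valid names are the class names in PRIMITIVE_TYPES: "Bbox", "Classification", "Polygon", "Bitmask".
bbox_type = get_primitive_type_from_string("Bbox")

# Unknown names raise ValueError listing the available primitives.
try:
    get_primitive_type_from_string("Keypoints")
except ValueError as err:
    print(err)
```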
hafnia/dataset/primitives/classification.py CHANGED
@@ -38,7 +38,7 @@ class Classification(Primitive):
             text = class_name
         else:
             text = f"{self.task_name}: {class_name}"
-        image = image_visualizations.append_text_below_frame(image, text=text)
+        image = image_visualizations.append_text_below_frame(image, text=text, text_size_ratio=0.05)
 
         return image
 
hafnia/experiment/hafnia_logger.py CHANGED
@@ -2,6 +2,7 @@ import json
 import os
 import platform
 import sys
+import textwrap
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -16,6 +17,16 @@ from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.log import sys_logger, user_logger
 from hafnia.utils import is_hafnia_cloud_job, now_as_str
 
+try:
+    import mlflow
+    import mlflow.tracking
+    import sagemaker_mlflow  # noqa: F401
+
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    user_logger.warning("MLFlow is not available")
+    MLFLOW_AVAILABLE = False
+
 
 class EntityType(Enum):
     """Types of entities that can be logged."""
@@ -87,11 +98,44 @@ class HafniaLogger:
         for path in create_paths:
             path.mkdir(parents=True, exist_ok=True)
 
+        path_file = self.path_model() / "HOW_TO_STORE_YOUR_MODEL.txt"
+        path_file.write_text(get_instructions_how_to_store_model())
+
         self.dataset_name: Optional[str] = None
         self.log_file = self._path_artifacts() / self.EXPERIMENT_FILE
         self.schema = Entity.create_schema()
+
+        # Initialize MLflow for remote jobs
+        self._mlflow_initialized = False
+        if is_hafnia_cloud_job() and MLFLOW_AVAILABLE:
+            self._init_mlflow()
+
         self.log_environment()
 
+    def _init_mlflow(self):
+        """Initialize MLflow tracking for remote jobs."""
+        try:
+            # Set MLflow tracking URI from environment variable
+            tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
+            if tracking_uri:
+                mlflow.set_tracking_uri(tracking_uri)
+                user_logger.info(f"MLflow tracking URI set to: {tracking_uri}")
+
+            # Set experiment name from environment variable
+            experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
+            if experiment_name:
+                mlflow.set_experiment(experiment_name)
+                user_logger.info(f"MLflow experiment set to: {experiment_name}")
+
+            # Start MLflow run
+            run_name = os.getenv("MLFLOW_RUN_NAME", "undefined")
+            mlflow.start_run(run_name=run_name)
+            self._mlflow_initialized = True
+            user_logger.info("MLflow run started successfully")
+
+        except Exception as e:
+            user_logger.error(f"Failed to initialize MLflow: {e}")
+
     def load_dataset(self, dataset_name: str) -> HafniaDataset:
         """
         Load a dataset from the specified path.
@@ -153,6 +197,14 @@
         )
         self.write_entity(entity)
 
+        # Also log to MLflow if initialized
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.log_metric(name, value, step=step)
+        except Exception as e:
+            user_logger.error(f"Failed to log metric to MLflow: {e}")
+
     def log_configuration(self, configurations: Dict):
         self.log_hparams(configurations, "configuration.json")
 
@@ -166,6 +218,15 @@
             existing_params.update(params)
             file_path.write_text(json.dumps(existing_params, indent=2))
             user_logger.info(f"Saved parameters to {file_path}")
+
+            # Also log to MLflow if initialized
+            if not self._mlflow_initialized:
+                return
+            try:
+                mlflow.log_params(params)
+            except Exception as e:
+                user_logger.error(f"Failed to log params to MLflow: {e}")
+
         except Exception as e:
             user_logger.error(f"Failed to save parameters to {file_path}: {e}")
 
@@ -202,3 +263,54 @@
             pq.write_table(next_table, self.log_file)
         except Exception as e:
             sys_logger.error(f"Failed to flush logs: {e}")
+
+    def end_run(self) -> None:
+        """End the MLflow run if initialized."""
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.end_run()
+            self._mlflow_initialized = False
+            user_logger.info("MLflow run ended successfully")
+        except Exception as e:
+            user_logger.error(f"Failed to end MLflow run: {e}")
+
+    def __del__(self):
+        """Cleanup when logger is destroyed."""
+        self.end_run()
+
+
+def get_instructions_how_to_store_model() -> str:
+    instructions = textwrap.dedent(
+        """\
+        If you, against your expectations, don't see any models in this folder,
+        we have provided a small guide to help.
+
+        The hafnia TaaS framework expects models to be stored in a folder generated
+        by the hafnia logger. You will need to store models in this folder
+        to ensure that they are properly stored and accessible after training.
+
+        Please check your recipe script and ensure that the models are being stored
+        as expected by the TaaS framework.
+
+        Below is also a small example to demonstrate:
+
+        ```python
+        from hafnia.experiment import HafniaLogger
+
+        # Initiate Hafnia logger
+        logger = HafniaLogger()
+
+        # Folder path to store models - generated by the hafnia logger.
+        model_dir = logger.path_model()
+
+        # Example for storing a pytorch based model. Note: the model is stored in 'model_dir'
+        path_pytorch_model = model_dir / "model.pth"
+
+        # Finally save the model to the specified path
+        torch.save(model.state_dict(), path_pytorch_model)
+        ```
+        """
+    )
+
+    return instructions
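
For orientation, a minimal sketch of how the logger additions fit together, using only methods visible in this diff (path_model, log_configuration, end_run) and the no-argument construction shown in the bundled HOW_TO_STORE_YOUR_MODEL.txt text; constructor arguments and any methods not shown here are unverified assumptions:

```python
import torch

from hafnia.experiment import HafniaLogger  # import path taken from the instructions text above

logger = HafniaLogger()  # no-arg construction as in the bundled snippet; adjust to your setup

# Hyperparameters are written to a local JSON file and, when MLflow was initialized
# for a cloud job, forwarded via mlflow.log_params().
logger.log_configuration({"lr": 1e-3, "epochs": 10})

# Models must be written into the folder created by the logger (see HOW_TO_STORE_YOUR_MODEL.txt).
model = torch.nn.Linear(4, 2)  # stand-in for a trained model
torch.save(model.state_dict(), logger.path_model() / "model.pth")

# New in 0.3.0: end_run() closes the MLflow run explicitly; __del__ also attempts it.
logger.end_run()
```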
hafnia/http.py CHANGED
@@ -1,11 +1,11 @@
 import json
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import urllib3
 
 
-def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
+def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Union[Dict, List]:
     """Fetches data from the API endpoint.
 
     Args:
@@ -81,3 +81,17 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], mult
         return json.loads(response.data.decode("utf-8"))
     finally:
         http.clear()
+
+
+def delete(endpoint: str, headers: Dict) -> Dict:
+    """Sends a DELETE request to the specified endpoint."""
+    http = urllib3.PoolManager(retries=urllib3.Retry(3))
+    try:
+        response = http.request("DELETE", endpoint, headers=headers)
+
+        if response.status not in (200, 204):
+            error_details = response.data.decode("utf-8")
+            raise urllib3.exceptions.HTTPError(f"Request failed with status {response.status}: {error_details}")
+        return json.loads(response.data.decode("utf-8")) if response.data else {}
+    finally:
+        http.clear()
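
A hedged sketch of the new delete() helper next to the widened fetch() return type; the base URL, token and response shape below are placeholders, not real platform values:

```python
from hafnia import http

API_KEY = "hafnia_pat_..."           # placeholder API key
BASE = "https://api.example.com/v1"  # placeholder endpoint, not a real Hafnia URL

headers = {"Authorization": API_KEY}

# fetch() may now return either a dict or a list, matching the Union[Dict, List] annotation.
recipes = http.fetch(f"{BASE}/dataset-recipes", headers=headers)

# delete() raises urllib3.exceptions.HTTPError for any status other than 200/204
# and returns {} when the response body is empty.
if isinstance(recipes, list) and recipes:
    http.delete(f"{BASE}/dataset-recipes/{recipes[0]['id']}", headers=headers)
```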
hafnia/platform/__init__.py CHANGED
@@ -1,3 +1,4 @@
+from hafnia.platform.datasets import get_dataset_id
 from hafnia.platform.download import (
     download_resource,
     download_single_object,
@@ -5,17 +6,22 @@ from hafnia.platform.download import (
 )
 from hafnia.platform.experiment import (
     create_experiment,
-    create_recipe,
-    get_dataset_id,
+    get_environments,
     get_exp_environment_id,
+    pretty_print_training_environments,
 )
+from hafnia.platform.trainer_package import create_trainer_package, get_trainer_package_by_id, get_trainer_packages
 
 __all__ = [
     "get_dataset_id",
-    "create_recipe",
+    "create_trainer_package",
+    "get_trainer_packages",
+    "get_trainer_package_by_id",
     "get_exp_environment_id",
     "create_experiment",
    "download_resource",
     "download_single_object",
     "get_resource_credentials",
+    "pretty_print_training_environments",
+    "get_environments",
 ]
hafnia/platform/builder.py CHANGED
@@ -14,26 +14,28 @@ from hafnia.log import sys_logger, user_logger
 from hafnia.platform import download_resource
 
 
-def validate_recipe_format(path: Path) -> None:
-    """Validate Hafnia Recipe Format submition"""
+def validate_trainer_package_format(path: Path) -> None:
+    """Validate Hafnia Trainer Package Format submission"""
     hrf = zipfile.Path(path) if path.suffix == ".zip" else path
     required = {"src", "scripts", "Dockerfile"}
     errors = 0
     for rp in required:
         if not (hrf / rp).exists():
-            user_logger.error(f"Required path {rp} not found in recipe.")
+            user_logger.error(f"Required path {rp} not found in trainer package.")
             errors += 1
     if errors > 0:
-        raise FileNotFoundError("Wrong recipe structure")
+        raise FileNotFoundError("Wrong trainer package structure")
 
 
-def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None) -> Dict:
-    resource = download_resource(recipe_url, output_dir.as_posix(), api_key)
-    recipe_path = Path(resource["downloaded_files"][0])
-    with zipfile.ZipFile(recipe_path, "r") as zip_ref:
+def prepare_trainer_package(
+    trainer_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None
+) -> Dict:
+    resource = download_resource(trainer_url, output_dir.as_posix(), api_key)
+    trainer_path = Path(resource["downloaded_files"][0])
+    with zipfile.ZipFile(trainer_path, "r") as zip_ref:
         zip_ref.extractall(output_dir)
 
-    validate_recipe_format(output_dir)
+    validate_trainer_package_format(output_dir)
 
     scripts_dir = output_dir / "scripts"
     if not any(scripts_dir.iterdir()):
@@ -42,7 +44,7 @@ def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file:
     metadata = {
         "user_data": (output_dir / "src").as_posix(),
         "dockerfile": (output_dir / "Dockerfile").as_posix(),
-        "digest": sha256(recipe_path.read_bytes()).hexdigest()[:8],
+        "digest": sha256(trainer_path.read_bytes()).hexdigest()[:8],
     }
     state_file = state_file if state_file else output_dir / "state.json"
     with open(state_file, "w", encoding="utf-8") as f:
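
A short sketch of the renamed validation helper, assuming the module lands at hafnia/platform/builder.py as listed above; the folder name is a placeholder:

```python
from pathlib import Path

from hafnia.platform.builder import validate_trainer_package_format

# A trainer package (formerly a "recipe") must contain src/, scripts/ and a Dockerfile,
# either as a directory or as a .zip archive; missing paths are logged and
# FileNotFoundError("Wrong trainer package structure") is raised.
validate_trainer_package_format(Path("my_trainer_package"))
```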
hafnia/platform/dataset_recipe.py ADDED
@@ -0,0 +1,99 @@
+import json
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from flatten_dict import flatten
+
+from hafnia import http
+from hafnia.log import user_logger
+from hafnia.utils import pretty_print_list_as_table, timed
+
+
+@timed("Get or create dataset recipe")
+def get_or_create_dataset_recipe(
+    recipe: dict, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    data = {"template": {"body": recipe}}
+    if name is not None:
+        data["name"] = name  # type: ignore[assignment]
+    response = http.post(endpoint, headers=headers, data=data)
+    return response
+
+
+def get_or_create_dataset_recipe_by_dataset_name(dataset_name: str, endpoint: str, api_key: str) -> Dict:
+    return get_or_create_dataset_recipe(recipe=dataset_name, endpoint=endpoint, api_key=api_key)
+
+
+def get_dataset_recipes(endpoint: str, api_key: str) -> List[Dict]:
+    headers = {"Authorization": api_key}
+    dataset_recipes: List[Dict] = http.fetch(endpoint, headers=headers)  # type: ignore[assignment]
+    return dataset_recipes
+
+
+def get_dataset_recipe_by_id(dataset_recipe_id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{dataset_recipe_id}"
+    dataset_recipe_info: Dict = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_recipe_info:
+        raise ValueError(f"Dataset recipe with ID '{dataset_recipe_id}' was not found.")
+    return dataset_recipe_info
+
+
+def get_or_create_dataset_recipe_from_path(
+    path_recipe_json: Path, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Dict:
+    path_recipe_json = Path(path_recipe_json)
+    if not path_recipe_json.exists():
+        raise FileNotFoundError(f"Dataset recipe file '{path_recipe_json}' does not exist.")
+    json_dict = json.loads(path_recipe_json.read_text())
+    return get_or_create_dataset_recipe(json_dict, endpoint=endpoint, api_key=api_key, name=name)
+
+
+def delete_dataset_recipe_by_id(id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{id}"
+    response = http.delete(endpoint=full_url, headers=headers)
+    return response
+
+
+@timed("Get dataset recipe")
+def get_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={name}"
+    dataset_recipes: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if len(dataset_recipes) == 0:
+        return None
+
+    if len(dataset_recipes) > 1:
+        user_logger.warning(f"Found {len(dataset_recipes)} dataset recipes called '{name}'. Using the first one.")
+
+    dataset_recipe = dataset_recipes[0]
+    return dataset_recipe
+
+
+def delete_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    recipe_response = get_dataset_recipe_by_name(name, endpoint=endpoint, api_key=api_key)
+
+    if recipe_response:
+        return delete_dataset_recipe_by_id(recipe_response["id"], endpoint=endpoint, api_key=api_key)
+    return recipe_response
+
+
+def pretty_print_dataset_recipes(recipes: List[Dict]) -> None:
+    recipes = [flatten(recipe, reducer="dot", max_flatten_depth=2) for recipe in recipes]  # noqa: F821
+    for recipe in recipes:
+        recipe["recipe_json"] = json.dumps(recipe["template.body"])[:20]
+
+    RECIPE_FIELDS = {
+        "ID": "id",
+        "Name": "name",
+        "Recipe": "recipe_json",
+        "Created": "created_at",
+        "IsDataset": "template.is_direct_dataset_reference",
+    }
+    pretty_print_list_as_table(
+        table_title="Available Dataset Recipes",
+        dict_items=recipes,
+        column_name_to_key_mapping=RECIPE_FIELDS,
+    )
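
A usage sketch for the new dataset-recipe helpers, assuming the module path hafnia/platform/dataset_recipe.py from the file list; the endpoint, API key and recipe file below are placeholders:

```python
from pathlib import Path

from hafnia.platform.dataset_recipe import (
    get_dataset_recipes,
    get_or_create_dataset_recipe_from_path,
    pretty_print_dataset_recipes,
)

ENDPOINT = "https://api.example.com/v1/dataset-recipes"  # placeholder, not the real platform endpoint
API_KEY = "hafnia_pat_..."                               # placeholder API key

# Register (or look up) a recipe stored as JSON on disk, then list what the platform returns.
recipe = get_or_create_dataset_recipe_from_path(
    Path("my_recipe.json"), endpoint=ENDPOINT, api_key=API_KEY, name="my-recipe"
)
pretty_print_dataset_recipes(get_dataset_recipes(endpoint=ENDPOINT, api_key=API_KEY))
```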
hafnia/platform/datasets.py CHANGED
@@ -8,10 +8,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import rich
+from rich import print as rprint
 from tqdm import tqdm
 
 from cli.config import Config
-from hafnia import utils
+from hafnia import http, utils
 from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
 from hafnia.dataset.dataset_recipe.dataset_recipe import (
     DatasetRecipe,
@@ -20,13 +21,12 @@ from hafnia.dataset.dataset_recipe.dataset_recipe import (
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
-from hafnia.platform import get_dataset_id
 from hafnia.platform.download import get_resource_credentials
 from hafnia.utils import timed
 
 
 @timed("Fetching dataset list.")
-def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+def get_datasets(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     """List available datasets on the Hafnia platform."""
     cfg = cfg or Config()
     endpoint_dataset = cfg.get_platform_endpoint("datasets")
@@ -38,6 +38,19 @@ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     return datasets
 
 
+@timed("Fetching dataset info.")
+def get_dataset_id(dataset_name: str, endpoint: str, api_key: str) -> str:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={dataset_name}"
+    dataset_responses: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_responses:
+        raise ValueError(f"Dataset '{dataset_name}' was not found in the dataset library.")
+    try:
+        return dataset_responses[0]["id"]
+    except (IndexError, KeyError) as e:
+        raise ValueError("Dataset information is missing or invalid") from e
+
+
 def download_or_get_dataset_path(
     dataset_name: str,
     cfg: Optional[Config] = None,
@@ -131,6 +144,28 @@ def fast_copy_files_s3(
     return lines
 
 
+def find_s5cmd() -> Optional[str]:
+    """Locate the s5cmd executable across different installation methods.
+
+    Searches for s5cmd in:
+    1. System PATH (via shutil.which)
+    2. Python bin directory (Unix-like systems)
+    3. Python executable directory (direct installs)
+
+    Returns:
+        str: Absolute path to s5cmd executable if found, None otherwise.
+    """
+    result = shutil.which("s5cmd")
+    if result:
+        return result
+    python_dir = Path(sys.executable).parent
+    locations = (python_dir / "Scripts" / "s5cmd.exe", python_dir / "bin" / "s5cmd", python_dir / "s5cmd")
+    for loc in locations:
+        if loc.exists():
+            return str(loc)
+    return None
+
+
 def execute_s5cmd_commands(
     commands: List[str],
     append_envs: Optional[Dict[str, str]] = None,
@@ -142,7 +177,10 @@
     with tempfile.TemporaryDirectory() as temp_dir:
         tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
         tmp_file_path.write_text("\n".join(commands))
-        s5cmd_bin = (Path(sys.executable).parent / "s5cmd").absolute().as_posix()
+
+        s5cmd_bin = find_s5cmd()
+        if s5cmd_bin is None:
+            raise ValueError("Can not find s5cmd executable.")
         run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
         sys_logger.debug(run_cmds)
         envs = os.environ.copy()
@@ -185,7 +223,7 @@
 }
 
 
-def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+def pretty_print_datasets(datasets: List[Dict[str, str]]) -> None:
     datasets = extend_dataset_details(datasets)
     datasets = sorted(datasets, key=lambda x: x["name"].lower())
 
@@ -197,7 +235,7 @@ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table
         row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
         table.add_row(*row)
 
-    return table
+    rprint(table)
 
 
 def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
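
A sketch of the renamed dataset helpers and the new s5cmd lookup, assuming the module path hafnia/platform/datasets.py from the file list; get_datasets() relies on a configured profile, and the dataset name, endpoint and key passed to get_dataset_id() are placeholders:

```python
from hafnia.platform.datasets import find_s5cmd, get_dataset_id, get_datasets, pretty_print_datasets

# find_s5cmd() replaces the hard-coded Path(sys.executable).parent / "s5cmd" lookup and
# returns None instead of a broken path when s5cmd is not installed.
if find_s5cmd() is None:
    print("s5cmd not found; S3 fast-copy commands will fail")

# get_datasets() (renamed from dataset_list) fetches the dataset list; pretty_print_datasets()
# (renamed from create_rich_table_from_dataset) now prints the table instead of returning it.
datasets = get_datasets()
pretty_print_datasets(datasets)

# get_dataset_id() moved here from hafnia.platform.experiment.
dataset_id = get_dataset_id("my-dataset", endpoint="https://api.example.com/v1/datasets", api_key="hafnia_pat_...")
```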
hafnia/platform/download.py CHANGED
@@ -96,7 +96,8 @@ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials
         RuntimeError: If the call to fetch the credentials fails for any reason.
     """
     try:
-        credentials_dict = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+        headers = {"Authorization": api_key, "accept": "application/json"}
+        credentials_dict: Dict = fetch(endpoint, headers=headers)  # type: ignore[assignment]
         credentials = ResourceCredentials.fix_naming(credentials_dict)
         sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
         return credentials