hafnia 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +16 -3
- cli/config.py +45 -4
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/keychain.py +88 -0
- cli/profile_cmds.py +10 -6
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +9 -14
- hafnia/dataset/dataset_names.py +10 -5
- hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
- hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +265 -56
- hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
- hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
- hafnia/dataset/hafnia_dataset.py +577 -213
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +259 -3
- hafnia/dataset/operations/dataset_transformations.py +332 -7
- hafnia/dataset/operations/table_transformations.py +43 -5
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/bbox.py +25 -12
- hafnia/dataset/primitives/bitmask.py +26 -14
- hafnia/dataset/primitives/classification.py +16 -8
- hafnia/dataset/primitives/point.py +7 -3
- hafnia/dataset/primitives/polygon.py +16 -9
- hafnia/dataset/primitives/segmentation.py +10 -7
- hafnia/experiment/hafnia_logger.py +111 -8
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +104 -0
- hafnia/platform/datasets.py +47 -9
- hafnia/platform/download.py +25 -19
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +81 -13
- hafnia/visualizations/image_visualizations.py +4 -4
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
- hafnia-0.4.0.dist-info/RECORD +56 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/primitives/polygon.py
CHANGED

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import cv2
 import numpy as np
+from pydantic import Field
 
 from hafnia.dataset.primitives.bitmask import Bitmask
 from hafnia.dataset.primitives.point import Point
@@ -11,15 +12,21 @@ from hafnia.dataset.primitives.utils import class_color_by_name, get_class_name
 
 class Polygon(Primitive):
     # Names should match names in FieldName
-    points: List[Point]
-    class_name: Optional[str] = None
-    class_idx: Optional[int] = None
-    object_id: Optional[str] = None
-    confidence: Optional[float] =
-
-
-
-
+    points: List[Point] = Field(description="List of points defining the polygon")
+    class_name: Optional[str] = Field(default=None, description="Class name of the polygon")
+    class_idx: Optional[int] = Field(default=None, description="Class index of the polygon")
+    object_id: Optional[str] = Field(default=None, description="Object ID of the polygon")
+    confidence: Optional[float] = Field(
+        default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
+    )
+    ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
+
+    task_name: str = Field(
+        default="", description="Task name to support multiple Polygon tasks in the same dataset. Defaults to 'polygon'"
+    )
+    meta: Optional[Dict[str, Any]] = Field(
+        default=None, description="This can be used to store additional information about the polygon"
+    )
 
     @staticmethod
     def from_list_of_points(
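
The Polygon schema above moves from plain defaults to pydantic `Field(...)` declarations. As a minimal standalone sketch (not hafnia code), this is what the switch buys: per-field descriptions become part of the model's generated JSON schema:

```python
from typing import Optional

from pydantic import BaseModel, Field


class ExamplePrimitive(BaseModel):
    # Illustrative stand-in for the hafnia primitives; field names mirror the diff above.
    class_name: Optional[str] = Field(default=None, description="Class name of the annotation")
    confidence: Optional[float] = Field(default=None, description="Confidence score (0-1.0)")
    ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")


# Descriptions are emitted into the JSON schema, unlike plain `= None` defaults.
schema = ExamplePrimitive.model_json_schema()
print(schema["properties"]["confidence"]["description"])  # -> Confidence score (0-1.0)
```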

hafnia/dataset/primitives/segmentation.py
CHANGED

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import cv2
 import numpy as np
+from pydantic import Field
 
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.dataset.primitives.utils import get_class_name
@@ -9,15 +10,17 @@ from hafnia.visualizations.colors import get_n_colors
 
 
 class Segmentation(Primitive):
-    #
-    class_names: Optional[List[str]] = None
-    ground_truth: bool = True
+    # WARNING: Segmentation masks have not been fully implemented yet
+    class_names: Optional[List[str]] = Field(default=None, description="Class names of the segmentation")
+    ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
-
-
-    "
+    task_name: str = Field(
+        default="",
+        description="Task name to support multiple Segmentation tasks in the same dataset. Defaults to 'segmentation'",
+    )
+    meta: Optional[Dict[str, Any]] = Field(
+        default=None, description="This can be used to store additional information about the segmentation"
     )
-    meta: Optional[Dict[str, Any]] = None  # This can be used to store additional information about the bitmask
 
     @staticmethod
     def default_task_name() -> str:

hafnia/experiment/hafnia_logger.py
CHANGED

@@ -2,6 +2,7 @@ import json
 import os
 import platform
 import sys
+import textwrap
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -11,11 +12,19 @@ import pyarrow as pa
 import pyarrow.parquet as pq
 from pydantic import BaseModel, field_validator
 
-from hafnia.data.factory import load_dataset
-from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.log import sys_logger, user_logger
 from hafnia.utils import is_hafnia_cloud_job, now_as_str
 
+try:
+    import mlflow
+    import mlflow.tracking
+    import sagemaker_mlflow  # noqa: F401
+
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    user_logger.warning("MLFlow is not available")
+    MLFLOW_AVAILABLE = False
+
 
 class EntityType(Enum):
     """Types of entities that can be logged."""
@@ -87,17 +96,43 @@ class HafniaLogger:
         for path in create_paths:
             path.mkdir(parents=True, exist_ok=True)
 
+        path_file = self.path_model() / "HOW_TO_STORE_YOUR_MODEL.txt"
+        path_file.write_text(get_instructions_how_to_store_model())
+
         self.dataset_name: Optional[str] = None
         self.log_file = self._path_artifacts() / self.EXPERIMENT_FILE
         self.schema = Entity.create_schema()
+
+        # Initialize MLflow for remote jobs
+        self._mlflow_initialized = False
+        if is_hafnia_cloud_job() and MLFLOW_AVAILABLE:
+            self._init_mlflow()
+
         self.log_environment()
 
-    def
-        """
-
-
-
-
+    def _init_mlflow(self):
+        """Initialize MLflow tracking for remote jobs."""
+        try:
+            # Set MLflow tracking URI from environment variable
+            tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
+            if tracking_uri:
+                mlflow.set_tracking_uri(tracking_uri)
+                user_logger.info(f"MLflow tracking URI set to: {tracking_uri}")
+
+            # Set experiment name from environment variable
+            experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
+            if experiment_name:
+                mlflow.set_experiment(experiment_name)
+                user_logger.info(f"MLflow experiment set to: {experiment_name}")
+
+            # Start MLflow run
+            run_name = os.getenv("MLFLOW_RUN_NAME", "undefined")
+            mlflow.start_run(run_name=run_name)
+            self._mlflow_initialized = True
+            user_logger.info("MLflow run started successfully")
+
+        except Exception as e:
+            user_logger.error(f"Failed to initialize MLflow: {e}")
 
     def path_local_experiment(self) -> Path:
         """Get the path for local experiment."""
@@ -153,6 +188,14 @@ class HafniaLogger:
         )
         self.write_entity(entity)
 
+        # Also log to MLflow if initialized
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.log_metric(name, value, step=step)
+        except Exception as e:
+            user_logger.error(f"Failed to log metric to MLflow: {e}")
+
     def log_configuration(self, configurations: Dict):
         self.log_hparams(configurations, "configuration.json")
 
@@ -166,6 +209,15 @@ class HafniaLogger:
             existing_params.update(params)
             file_path.write_text(json.dumps(existing_params, indent=2))
             user_logger.info(f"Saved parameters to {file_path}")
+
+            # Also log to MLflow if initialized
+            if not self._mlflow_initialized:
+                return
+            try:
+                mlflow.log_params(params)
+            except Exception as e:
+                user_logger.error(f"Failed to log params to MLflow: {e}")
+
         except Exception as e:
             user_logger.error(f"Failed to save parameters to {file_path}: {e}")
 
@@ -202,3 +254,54 @@ class HafniaLogger:
             pq.write_table(next_table, self.log_file)
         except Exception as e:
             sys_logger.error(f"Failed to flush logs: {e}")
+
+    def end_run(self) -> None:
+        """End the MLflow run if initialized."""
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.end_run()
+            self._mlflow_initialized = False
+            user_logger.info("MLflow run ended successfully")
+        except Exception as e:
+            user_logger.error(f"Failed to end MLflow run: {e}")
+
+    def __del__(self):
+        """Cleanup when logger is destroyed."""
+        self.end_run()
+
+
+def get_instructions_how_to_store_model() -> str:
+    instructions = textwrap.dedent(
+        """\
+        If you, against your expectations, don't see any models in this folder,
+        we have provided a small guide to help.
+
+        The hafnia TaaS framework expects models to be stored in a folder generated
+        by the hafnia logger. You will need to store models in this folder
+        to ensure that they are properly stored and accessible after training.
+
+        Please check your recipe script and ensure that the models are being stored
+        as expected by the TaaS framework.
+
+        Below is also a small example to demonstrate:
+
+        ```python
+        from hafnia.experiment import HafniaLogger
+
+        # Initiate Hafnia logger
+        logger = HafniaLogger()
+
+        # Folder path to store models - generated by the hafnia logger.
+        model_dir = logger.path_model()
+
+        # Example for storing a pytorch based model. Note: the model is stored in 'model_dir'
+        path_pytorch_model = model_dir / "model.pth"
+
+        # Finally save the model to the specified path
+        torch.save(model.state_dict(), path_pytorch_model)
+        ```
+        """
+    )
+
+    return instructions
hafnia/http.py
CHANGED

@@ -1,11 +1,11 @@
 import json
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import urllib3
 
 
-def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
+def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Union[Dict, List]:
     """Fetches data from the API endpoint.
 
     Args:
@@ -81,3 +81,17 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], mult
         return json.loads(response.data.decode("utf-8"))
     finally:
         http.clear()
+
+
+def delete(endpoint: str, headers: Dict) -> Dict:
+    """Sends a DELETE request to the specified endpoint."""
+    http = urllib3.PoolManager(retries=urllib3.Retry(3))
+    try:
+        response = http.request("DELETE", endpoint, headers=headers)
+
+        if response.status not in (200, 204):
+            error_details = response.data.decode("utf-8")
+            raise urllib3.exceptions.HTTPError(f"Request failed with status {response.status}: {error_details}")
+        return json.loads(response.data.decode("utf-8")) if response.data else {}
+    finally:
+        http.clear()
hafnia/platform/__init__.py
CHANGED

@@ -1,3 +1,4 @@
+from hafnia.platform.datasets import get_dataset_id
 from hafnia.platform.download import (
     download_resource,
     download_single_object,
@@ -5,17 +6,22 @@ from hafnia.platform.download import (
 )
 from hafnia.platform.experiment import (
     create_experiment,
-
-    get_dataset_id,
+    get_environments,
     get_exp_environment_id,
+    pretty_print_training_environments,
 )
+from hafnia.platform.trainer_package import create_trainer_package, get_trainer_package_by_id, get_trainer_packages
 
 __all__ = [
     "get_dataset_id",
-    "
+    "create_trainer_package",
+    "get_trainer_packages",
+    "get_trainer_package_by_id",
     "get_exp_environment_id",
     "create_experiment",
     "download_resource",
     "download_single_object",
     "get_resource_credentials",
+    "pretty_print_training_environments",
+    "get_environments",
 ]
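
With the re-exports above, the new trainer-package and environment helpers are importable directly from `hafnia.platform`:

```python
from hafnia.platform import (
    create_trainer_package,
    get_dataset_id,
    get_environments,
    get_trainer_packages,
    pretty_print_training_environments,
)
```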
hafnia/platform/builder.py
CHANGED

@@ -14,26 +14,28 @@ from hafnia.log import sys_logger, user_logger
 from hafnia.platform import download_resource
 
 
-def
-    """Validate Hafnia
+def validate_trainer_package_format(path: Path) -> None:
+    """Validate Hafnia Trainer Package Format submission"""
     hrf = zipfile.Path(path) if path.suffix == ".zip" else path
     required = {"src", "scripts", "Dockerfile"}
     errors = 0
     for rp in required:
         if not (hrf / rp).exists():
-            user_logger.error(f"Required path {rp} not found in
+            user_logger.error(f"Required path {rp} not found in trainer package.")
             errors += 1
     if errors > 0:
-        raise FileNotFoundError("Wrong
+        raise FileNotFoundError("Wrong trainer package structure")
 
 
-def
-
-
-
+def prepare_trainer_package(
+    trainer_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None
+) -> Dict:
+    resource = download_resource(trainer_url, output_dir.as_posix(), api_key)
+    trainer_path = Path(resource["downloaded_files"][0])
+    with zipfile.ZipFile(trainer_path, "r") as zip_ref:
         zip_ref.extractall(output_dir)
 
-
+    validate_trainer_package_format(output_dir)
 
     scripts_dir = output_dir / "scripts"
     if not any(scripts_dir.iterdir()):
@@ -42,7 +44,7 @@ def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file:
     metadata = {
         "user_data": (output_dir / "src").as_posix(),
         "dockerfile": (output_dir / "Dockerfile").as_posix(),
-        "digest": sha256(
+        "digest": sha256(trainer_path.read_bytes()).hexdigest()[:8],
     }
     state_file = state_file if state_file else output_dir / "state.json"
     with open(state_file, "w", encoding="utf-8") as f:
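
A hedged sketch of the renamed builder helpers in use; the URL, key and output path are placeholders. `prepare_trainer_package()` downloads the packaged trainer, unzips it, validates the `src`/`scripts`/`Dockerfile` layout and writes a state file with the digest:

```python
from pathlib import Path

from hafnia.platform.builder import prepare_trainer_package, validate_trainer_package_format

state = prepare_trainer_package(
    trainer_url="https://api.example.com/api/v1/trainer-packages/123/download",  # placeholder
    output_dir=Path("./trainer_package"),
    api_key="HAFNIA_API_KEY_PLACEHOLDER",
)

# The format check can also be run on its own against an unpacked folder or a .zip file.
validate_trainer_package_format(Path("./trainer_package"))
```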
hafnia/platform/dataset_recipe.py
ADDED

@@ -0,0 +1,104 @@
+import json
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from flatten_dict import flatten
+
+from hafnia import http
+from hafnia.log import user_logger
+from hafnia.utils import pretty_print_list_as_table, timed
+
+
+@timed("Get or create dataset recipe")
+def get_or_create_dataset_recipe(
+    recipe: dict,
+    endpoint: str,
+    api_key: str,
+    name: Optional[str] = None,
+    overwrite: bool = False,
+) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    data = {"template": {"body": recipe}, "overwrite": overwrite}
+    if name is not None:
+        data["name"] = name  # type: ignore[assignment]
+
+    response = http.post(endpoint, headers=headers, data=data)
+    return response
+
+
+def get_or_create_dataset_recipe_by_dataset_name(dataset_name: str, endpoint: str, api_key: str) -> Dict:
+    return get_or_create_dataset_recipe(recipe=dataset_name, endpoint=endpoint, api_key=api_key)
+
+
+def get_dataset_recipes(endpoint: str, api_key: str) -> List[Dict]:
+    headers = {"Authorization": api_key}
+    dataset_recipes: List[Dict] = http.fetch(endpoint, headers=headers)  # type: ignore[assignment]
+    return dataset_recipes
+
+
+def get_dataset_recipe_by_id(dataset_recipe_id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{dataset_recipe_id}"
+    dataset_recipe_info: Dict = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_recipe_info:
+        raise ValueError(f"Dataset recipe with ID '{dataset_recipe_id}' was not found.")
+    return dataset_recipe_info
+
+
+def get_or_create_dataset_recipe_from_path(
+    path_recipe_json: Path, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Dict:
+    path_recipe_json = Path(path_recipe_json)
+    if not path_recipe_json.exists():
+        raise FileNotFoundError(f"Dataset recipe file '{path_recipe_json}' does not exist.")
+    json_dict = json.loads(path_recipe_json.read_text())
+    return get_or_create_dataset_recipe(json_dict, endpoint=endpoint, api_key=api_key, name=name)
+
+
+def delete_dataset_recipe_by_id(id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{id}"
+    response = http.delete(endpoint=full_url, headers=headers)
+    return response
+
+
+@timed("Get dataset recipe")
+def get_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={name}"
+    dataset_recipes: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if len(dataset_recipes) == 0:
+        return None
+
+    if len(dataset_recipes) > 1:
+        user_logger.warning(f"Found {len(dataset_recipes)} dataset recipes called '{name}'. Using the first one.")
+
+    dataset_recipe = dataset_recipes[0]
+    return dataset_recipe
+
+
+def delete_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    recipe_response = get_dataset_recipe_by_name(name, endpoint=endpoint, api_key=api_key)
+
+    if recipe_response:
+        return delete_dataset_recipe_by_id(recipe_response["id"], endpoint=endpoint, api_key=api_key)
+    return recipe_response
+
+
+def pretty_print_dataset_recipes(recipes: List[Dict]) -> None:
+    recipes = [flatten(recipe, reducer="dot", max_flatten_depth=2) for recipe in recipes]  # noqa: F821
+    for recipe in recipes:
+        recipe["recipe_json"] = json.dumps(recipe["template.body"])[:20]
+
+    RECIPE_FIELDS = {
+        "ID": "id",
+        "Name": "name",
+        "Recipe": "recipe_json",
+        "Created": "created_at",
+        "IsDataset": "template.is_direct_dataset_reference",
+    }
+    pretty_print_list_as_table(
+        table_title="Available Dataset Recipes",
+        dict_items=recipes,
+        column_name_to_key_mapping=RECIPE_FIELDS,
+    )
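
A hedged usage sketch for the new `hafnia/platform/dataset_recipe.py` helpers; the endpoint, API key and recipe body are placeholders:

```python
from hafnia.platform.dataset_recipe import (
    delete_dataset_recipe_by_name,
    get_dataset_recipe_by_name,
    get_or_create_dataset_recipe,
)

endpoint = "https://api.example.com/api/v1/dataset-recipes"  # placeholder
api_key = "HAFNIA_API_KEY_PLACEHOLDER"

# Register (or fetch) a recipe under an explicit name; the body here is a toy example.
created = get_or_create_dataset_recipe(
    recipe={"dataset_name": "mnist"},
    endpoint=endpoint,
    api_key=api_key,
    name="my-recipe",
)

# Look it up by name (the platform query uses name__iexact) and delete it again.
found = get_dataset_recipe_by_name("my-recipe", endpoint=endpoint, api_key=api_key)
if found is not None:
    delete_dataset_recipe_by_name("my-recipe", endpoint=endpoint, api_key=api_key)
```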
hafnia/platform/datasets.py
CHANGED

@@ -8,10 +8,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import rich
-from
+from rich import print as rprint
+from rich.progress import track
 
 from cli.config import Config
-from hafnia import utils
+from hafnia import http, utils
 from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
 from hafnia.dataset.dataset_recipe.dataset_recipe import (
     DatasetRecipe,
@@ -20,13 +21,12 @@ from hafnia.dataset.dataset_recipe.dataset_recipe import (
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
-from hafnia.platform import get_dataset_id
 from hafnia.platform.download import get_resource_credentials
 from hafnia.utils import timed
 
 
 @timed("Fetching dataset list.")
-def
+def get_datasets(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     """List available datasets on the Hafnia platform."""
     cfg = cfg or Config()
     endpoint_dataset = cfg.get_platform_endpoint("datasets")
@@ -38,6 +38,19 @@ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     return datasets
 
 
+@timed("Fetching dataset info.")
+def get_dataset_id(dataset_name: str, endpoint: str, api_key: str) -> str:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={dataset_name}"
+    dataset_responses: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_responses:
+        raise ValueError(f"Dataset '{dataset_name}' was not found in the dataset library.")
+    try:
+        return dataset_responses[0]["id"]
+    except (IndexError, KeyError) as e:
+        raise ValueError("Dataset information is missing or invalid") from e
+
+
 def download_or_get_dataset_path(
     dataset_name: str,
     cfg: Optional[Config] = None,
@@ -109,7 +122,7 @@ def download_dataset_from_access_endpoint(
     try:
         fast_copy_files_s3(
             src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
-            dst_paths=dataset.samples[ColumnName.
+            dst_paths=dataset.samples[ColumnName.FILE_PATH].to_list(),
             append_envs=envs,
             description="Downloading images",
         )
@@ -131,6 +144,28 @@ def fast_copy_files_s3(
     return lines
 
 
+def find_s5cmd() -> Optional[str]:
+    """Locate the s5cmd executable across different installation methods.
+
+    Searches for s5cmd in:
+    1. System PATH (via shutil.which)
+    2. Python bin directory (Unix-like systems)
+    3. Python executable directory (direct installs)
+
+    Returns:
+        str: Absolute path to s5cmd executable if found, None otherwise.
+    """
+    result = shutil.which("s5cmd")
+    if result:
+        return result
+    python_dir = Path(sys.executable).parent
+    locations = (python_dir / "Scripts" / "s5cmd.exe", python_dir / "bin" / "s5cmd", python_dir / "s5cmd")
+    for loc in locations:
+        if loc.exists():
+            return str(loc)
+    return None
+
+
 def execute_s5cmd_commands(
     commands: List[str],
     append_envs: Optional[Dict[str, str]] = None,
@@ -142,7 +177,10 @@ def execute_s5cmd_commands(
     with tempfile.TemporaryDirectory() as temp_dir:
         tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
         tmp_file_path.write_text("\n".join(commands))
-
+
+        s5cmd_bin = find_s5cmd()
+        if s5cmd_bin is None:
+            raise ValueError("Can not find s5cmd executable.")
         run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
         sys_logger.debug(run_cmds)
         envs = os.environ.copy()
@@ -158,7 +196,7 @@ def execute_s5cmd_commands(
 
         error_lines = []
         lines = []
-        for line in
+        for line in track(process.stdout, total=len(commands), description=description):
            if "ERROR" in line or "error" in line:
                error_lines.append(line.strip())
            lines.append(line.strip())
@@ -185,7 +223,7 @@ TABLE_FIELDS = {
 }
 
 
-def
+def pretty_print_datasets(datasets: List[Dict[str, str]]) -> None:
     datasets = extend_dataset_details(datasets)
     datasets = sorted(datasets, key=lambda x: x["name"].lower())
 
@@ -197,7 +235,7 @@ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table
        row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
        table.add_row(*row)
 
-
+    rprint(table)
 
 
 def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
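
A hedged sketch for two of the additions above: resolving a dataset id by name and locating the `s5cmd` binary. Endpoint, key and dataset name are placeholders:

```python
from hafnia.platform.datasets import find_s5cmd, get_dataset_id

dataset_id = get_dataset_id(
    dataset_name="mnist",  # placeholder dataset name
    endpoint="https://api.example.com/api/v1/datasets",  # placeholder
    api_key="HAFNIA_API_KEY_PLACEHOLDER",
)
print(dataset_id)

# execute_s5cmd_commands() now fails early with a clear error when s5cmd is missing.
if find_s5cmd() is None:
    print("s5cmd was not found on PATH or next to the Python executable.")
```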
hafnia/platform/download.py
CHANGED

@@ -1,10 +1,10 @@
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Optional
 
 import boto3
 from botocore.exceptions import ClientError
 from pydantic import BaseModel, field_validator
-from
+from rich.progress import Progress
 
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
@@ -96,7 +96,8 @@ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials
         RuntimeError: If the call to fetch the credentials fails for any reason.
     """
     try:
-
+        headers = {"Authorization": api_key, "accept": "application/json"}
+        credentials_dict: Dict = fetch(endpoint, headers=headers)  # type: ignore[assignment]
         credentials = ResourceCredentials.fix_naming(credentials_dict)
         sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
         return credentials
@@ -124,13 +125,15 @@ def download_single_object(s3_client, bucket: str, object_key: str, output_dir:
     return local_path
 
 
-def download_resource(resource_url: str, destination: str, api_key: str) -> Dict:
+def download_resource(resource_url: str, destination: str, api_key: str, prefix: Optional[str] = None) -> Dict:
     """
     Downloads either a single file from S3 or all objects under a prefix.
 
     Args:
         resource_url (str): The URL or identifier used to fetch S3 credentials.
        destination (str): Path to local directory where files will be stored.
+        api_key (str): API key for authentication when fetching credentials.
+        prefix (Optional[str]): If provided, only download objects under this prefix.
 
     Returns:
        Dict[str, Any]: A dictionary containing download info, e.g.:
@@ -146,7 +149,7 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
     res_credentials = get_resource_credentials(resource_url, api_key)
 
     bucket_name = res_credentials.bucket_name()
-
+    prefix = prefix or res_credentials.object_key()
 
     output_path = Path(destination)
     output_path.mkdir(parents=True, exist_ok=True)
@@ -158,29 +161,32 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
     )
     downloaded_files = []
     try:
-        s3_client.head_object(Bucket=bucket_name, Key=
-        local_file = download_single_object(s3_client, bucket_name,
+        s3_client.head_object(Bucket=bucket_name, Key=prefix)
+        local_file = download_single_object(s3_client, bucket_name, prefix, output_path)
         downloaded_files.append(str(local_file))
         user_logger.info(f"Downloaded single file: {local_file}")
 
     except ClientError as e:
         error_code = e.response.get("Error", {}).get("Code")
         if error_code == "404":
-            sys_logger.debug(f"Object '{
-            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=
+            sys_logger.debug(f"Object '{prefix}' not found; trying as a prefix.")
+            response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
            contents = response.get("Contents", [])
 
            if not contents:
-                raise ValueError(f"No objects found for prefix '{
-
-
-
-
-
-
-
-
+                raise ValueError(f"No objects found for prefix '{prefix}' in bucket '{bucket_name}'")
+
+            with Progress() as progress:
+                task = progress.add_task("Downloading files", total=len(contents))
+                for obj in contents:
+                    sub_key = obj["Key"]
+                    size_mb = obj.get("Size", 0) / 1024 / 1024
+                    progress.update(task, description=f"Downloading {sub_key} ({size_mb:.2f} MB)")
+                    local_file = download_single_object(s3_client, bucket_name, sub_key, output_path)
+                    downloaded_files.append(local_file.as_posix())
+                    progress.advance(task)
+
+            user_logger.info(f"Downloaded folder/prefix '{prefix}' with {len(downloaded_files)} object(s).")
        else:
            user_logger.error(f"Error checking object or prefix: {e}")
            raise RuntimeError(f"Failed to check or download S3 resource: {e}") from e