hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +13 -2
- cli/config.py +2 -1
- cli/consts.py +1 -1
- cli/dataset_cmds.py +6 -14
- cli/dataset_recipe_cmds.py +78 -0
- cli/experiment_cmds.py +226 -43
- cli/profile_cmds.py +6 -5
- cli/runc_cmds.py +5 -5
- cli/trainer_package_cmds.py +65 -0
- hafnia/__init__.py +2 -0
- hafnia/data/factory.py +1 -2
- hafnia/dataset/dataset_helpers.py +0 -12
- hafnia/dataset/dataset_names.py +8 -4
- hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
- hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
- hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
- hafnia/dataset/dataset_upload_helper.py +206 -53
- hafnia/dataset/hafnia_dataset.py +432 -194
- hafnia/dataset/license_types.py +63 -0
- hafnia/dataset/operations/dataset_stats.py +260 -3
- hafnia/dataset/operations/dataset_transformations.py +325 -4
- hafnia/dataset/operations/table_transformations.py +39 -2
- hafnia/dataset/primitives/__init__.py +8 -0
- hafnia/dataset/primitives/classification.py +1 -1
- hafnia/experiment/hafnia_logger.py +112 -0
- hafnia/http.py +16 -2
- hafnia/platform/__init__.py +9 -3
- hafnia/platform/builder.py +12 -10
- hafnia/platform/dataset_recipe.py +99 -0
- hafnia/platform/datasets.py +44 -6
- hafnia/platform/download.py +2 -1
- hafnia/platform/experiment.py +51 -56
- hafnia/platform/trainer_package.py +57 -0
- hafnia/utils.py +64 -13
- hafnia/visualizations/image_visualizations.py +3 -3
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
- hafnia-0.3.0.dist-info/RECORD +53 -0
- cli/recipe_cmds.py +0 -45
- hafnia-0.2.4.dist-info/RECORD +0 -49
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ from tqdm import tqdm
 from hafnia.dataset.dataset_names import (
     FILENAME_ANNOTATIONS_JSONL,
     FILENAME_ANNOTATIONS_PARQUET,
+    ColumnName,
     FieldName,
 )
 from hafnia.dataset.operations import table_transformations
@@ -34,8 +35,12 @@ def create_primitive_table(
 
     if keep_sample_data:
         # Drop other primitive columns to avoid conflicts
-
-
+
+        drop_columns_primitives = set(PRIMITIVE_TYPES) - {PrimitiveType, Classification}
+        drop_columns_names = [primitive.column_name() for primitive in drop_columns_primitives]
+        drop_columns_names = [c for c in drop_columns_names if c in remove_no_object_frames.columns]
+
+        remove_no_object_frames = remove_no_object_frames.drop(drop_columns_names)
         # Rename columns "height", "width" and "meta" for sample to avoid conflicts with object fields names
         remove_no_object_frames = remove_no_object_frames.rename(
             {"height": "image.height", "width": "image.width", "meta": "image.meta"}
@@ -46,6 +51,38 @@ def create_primitive_table(
     return objects_df
 
 
+def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFrame:
+    has_same_schema = samples0.schema == samples1.schema
+    if not has_same_schema:
+        shared_columns = []
+        for column_name, column_type in samples0.schema.items():
+            if column_name not in samples1.schema:
+                continue
+
+            if column_type != samples1.schema[column_name]:
+                continue
+            shared_columns.append(column_name)
+
+        dropped_columns0 = [
+            f"{n}[{ctype._string_repr()}]" for n, ctype in samples0.schema.items() if n not in shared_columns
+        ]
+        dropped_columns1 = [
+            f"{n}[{ctype._string_repr()}]" for n, ctype in samples1.schema.items() if n not in shared_columns
+        ]
+        user_logger.warning(
+            "Datasets with different schemas are being merged. "
+            "Only the columns with the same name and type will be kept in the merged dataset.\n"
+            f"Dropped columns in samples0: {dropped_columns0}\n"
+            f"Dropped columns in samples1: {dropped_columns1}\n"
+        )
+
+        samples0 = samples0.select(list(shared_columns))
+        samples1 = samples1.select(list(shared_columns))
+    merged_samples = pl.concat([samples0, samples1], how="vertical")
+    merged_samples = merged_samples.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+    return merged_samples
+
+
 def filter_table_for_class_names(
     samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
 ) -> Optional[pl.DataFrame]:
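The new `merge_samples` helper vertically concatenates two Polars sample tables; when the schemas differ it warns and keeps only columns that match in both name and dtype, then rebuilds the sample-index column. A minimal sketch of the behaviour, assuming `merge_samples` is in scope (its module path is not visible in this extract) and that `ColumnName.SAMPLE_INDEX` resolves to `"sample_index"`:

```python
import polars as pl

# Hypothetical sample tables with identical schemas; "sample_index" is assumed
# to be the value of ColumnName.SAMPLE_INDEX.
samples0 = pl.DataFrame({"sample_index": [0, 1], "file_name": ["a.jpg", "b.jpg"]})
samples1 = pl.DataFrame({"sample_index": [0, 1], "file_name": ["c.jpg", "d.jpg"]})

merged = merge_samples(samples0, samples1)
assert merged.height == 4
# The sample index is rebuilt for the concatenated table: 0, 1, 2, 3
print(merged["sample_index"].to_list())
```

With differing schemas, the non-shared columns would be dropped and a warning listing them would be emitted.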
@@ -14,3 +14,11 @@ from .utils import class_color_by_name  # noqa: F401
 PRIMITIVE_TYPES: List[Type[Primitive]] = [Bbox, Classification, Polygon, Bitmask]
 PRIMITIVE_NAME_TO_TYPE = {cls.__name__: cls for cls in PRIMITIVE_TYPES}
 PRIMITIVE_COLUMN_NAMES: List[str] = [PrimitiveType.column_name() for PrimitiveType in PRIMITIVE_TYPES]
+
+
+def get_primitive_type_from_string(name: str) -> Type[Primitive]:
+    if name not in PRIMITIVE_NAME_TO_TYPE:
+        raise ValueError(
+            f"Primitive '{name}' is not recognized. Available primitives: {list(PRIMITIVE_NAME_TO_TYPE.keys())}"
+        )
+    return PRIMITIVE_NAME_TO_TYPE[name]
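The `get_primitive_type_from_string` helper added to `hafnia/dataset/primitives/__init__.py` maps a primitive class name back to its type. A short illustrative usage, assuming `Bbox` is importable from the same package (it is referenced in `PRIMITIVE_TYPES` above):

```python
from hafnia.dataset.primitives import Bbox, get_primitive_type_from_string

bbox_type = get_primitive_type_from_string("Bbox")
assert bbox_type is Bbox

# Unknown names raise a ValueError that lists the available primitives.
try:
    get_primitive_type_from_string("Keypoint")  # "Keypoint" is not a registered primitive
except ValueError as err:
    print(err)
```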
@@ -38,7 +38,7 @@ class Classification(Primitive):
             text = class_name
         else:
             text = f"{self.task_name}: {class_name}"
-        image = image_visualizations.append_text_below_frame(image, text=text)
+        image = image_visualizations.append_text_below_frame(image, text=text, text_size_ratio=0.05)
 
         return image
 
@@ -2,6 +2,7 @@ import json
 import os
 import platform
 import sys
+import textwrap
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -16,6 +17,16 @@ from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.log import sys_logger, user_logger
 from hafnia.utils import is_hafnia_cloud_job, now_as_str
 
+try:
+    import mlflow
+    import mlflow.tracking
+    import sagemaker_mlflow  # noqa: F401
+
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    user_logger.warning("MLFlow is not available")
+    MLFLOW_AVAILABLE = False
+
 
 class EntityType(Enum):
     """Types of entities that can be logged."""
@@ -87,11 +98,44 @@ class HafniaLogger:
         for path in create_paths:
             path.mkdir(parents=True, exist_ok=True)
 
+        path_file = self.path_model() / "HOW_TO_STORE_YOUR_MODEL.txt"
+        path_file.write_text(get_instructions_how_to_store_model())
+
         self.dataset_name: Optional[str] = None
         self.log_file = self._path_artifacts() / self.EXPERIMENT_FILE
         self.schema = Entity.create_schema()
+
+        # Initialize MLflow for remote jobs
+        self._mlflow_initialized = False
+        if is_hafnia_cloud_job() and MLFLOW_AVAILABLE:
+            self._init_mlflow()
+
         self.log_environment()
 
+    def _init_mlflow(self):
+        """Initialize MLflow tracking for remote jobs."""
+        try:
+            # Set MLflow tracking URI from environment variable
+            tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
+            if tracking_uri:
+                mlflow.set_tracking_uri(tracking_uri)
+                user_logger.info(f"MLflow tracking URI set to: {tracking_uri}")
+
+            # Set experiment name from environment variable
+            experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
+            if experiment_name:
+                mlflow.set_experiment(experiment_name)
+                user_logger.info(f"MLflow experiment set to: {experiment_name}")
+
+            # Start MLflow run
+            run_name = os.getenv("MLFLOW_RUN_NAME", "undefined")
+            mlflow.start_run(run_name=run_name)
+            self._mlflow_initialized = True
+            user_logger.info("MLflow run started successfully")
+
+        except Exception as e:
+            user_logger.error(f"Failed to initialize MLflow: {e}")
+
     def load_dataset(self, dataset_name: str) -> HafniaDataset:
         """
         Load a dataset from the specified path.
@@ -153,6 +197,14 @@ class HafniaLogger:
         )
         self.write_entity(entity)
 
+        # Also log to MLflow if initialized
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.log_metric(name, value, step=step)
+        except Exception as e:
+            user_logger.error(f"Failed to log metric to MLflow: {e}")
+
     def log_configuration(self, configurations: Dict):
         self.log_hparams(configurations, "configuration.json")
 
@@ -166,6 +218,15 @@ class HafniaLogger:
             existing_params.update(params)
             file_path.write_text(json.dumps(existing_params, indent=2))
             user_logger.info(f"Saved parameters to {file_path}")
+
+            # Also log to MLflow if initialized
+            if not self._mlflow_initialized:
+                return
+            try:
+                mlflow.log_params(params)
+            except Exception as e:
+                user_logger.error(f"Failed to log params to MLflow: {e}")
+
         except Exception as e:
             user_logger.error(f"Failed to save parameters to {file_path}: {e}")
 
@@ -202,3 +263,54 @@ class HafniaLogger:
             pq.write_table(next_table, self.log_file)
         except Exception as e:
             sys_logger.error(f"Failed to flush logs: {e}")
+
+    def end_run(self) -> None:
+        """End the MLflow run if initialized."""
+        if not self._mlflow_initialized:
+            return
+        try:
+            mlflow.end_run()
+            self._mlflow_initialized = False
+            user_logger.info("MLflow run ended successfully")
+        except Exception as e:
+            user_logger.error(f"Failed to end MLflow run: {e}")
+
+    def __del__(self):
+        """Cleanup when logger is destroyed."""
+        self.end_run()
+
+
+def get_instructions_how_to_store_model() -> str:
+    instructions = textwrap.dedent(
+        """\
+        If you, against your expectations, don't see any models in this folder,
+        we have provided a small guide to help.
+
+        The hafnia TaaS framework expects models to be stored in a folder generated
+        by the hafnia logger. You will need to store models in this folder
+        to ensure that they are properly stored and accessible after training.
+
+        Please check your recipe script and ensure that the models are being stored
+        as expected by the TaaS framework.
+
+        Below is also a small example to demonstrate:
+
+        ```python
+        from hafnia.experiment import HafniaLogger
+
+        # Initiate Hafnia logger
+        logger = HafniaLogger()
+
+        # Folder path to store models - generated by the hafnia logger.
+        model_dir = logger.path_model()
+
+        # Example for storing a pytorch based model. Note: the model is stored in 'model_dir'
+        path_pytorch_model = model_dir / "model.pth"
+
+        # Finally save the model to the specified path
+        torch.save(model.state_dict(), path_pytorch_model)
+        ```
+        """
+    )
+
+    return instructions
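Taken together, the `HafniaLogger` changes write a `HOW_TO_STORE_YOUR_MODEL.txt` hint into the model folder and, on HafniaCloud jobs with MLflow available, mirror metrics and parameters to an MLflow run. A hedged usage sketch based only on the methods visible in this diff; the configuration values and dataset name are placeholders, and any extra constructor arguments are not shown here:

```python
from hafnia.experiment import HafniaLogger

logger = HafniaLogger()  # constructor call mirrors the embedded HOW_TO_STORE_YOUR_MODEL example

# Hyperparameters are saved locally and, when an MLflow run was started, forwarded via mlflow.log_params.
logger.log_configuration({"learning_rate": 1e-3, "epochs": 10})

# Datasets and model artifacts go through logger-managed paths.
dataset = logger.load_dataset("mnist")  # dataset name is illustrative
model_dir = logger.path_model()         # now also contains HOW_TO_STORE_YOUR_MODEL.txt

# Explicitly close the MLflow run (end_run is also attempted in __del__).
logger.end_run()
```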
hafnia/http.py CHANGED

@@ -1,11 +1,11 @@
 import json
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import urllib3
 
 
-def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
+def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Union[Dict, List]:
     """Fetches data from the API endpoint.
 
     Args:
@@ -81,3 +81,17 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], mult
         return json.loads(response.data.decode("utf-8"))
     finally:
         http.clear()
+
+
+def delete(endpoint: str, headers: Dict) -> Dict:
+    """Sends a DELETE request to the specified endpoint."""
+    http = urllib3.PoolManager(retries=urllib3.Retry(3))
+    try:
+        response = http.request("DELETE", endpoint, headers=headers)
+
+        if response.status not in (200, 204):
+            error_details = response.data.decode("utf-8")
+            raise urllib3.exceptions.HTTPError(f"Request failed with status {response.status}: {error_details}")
+        return json.loads(response.data.decode("utf-8")) if response.data else {}
+    finally:
+        http.clear()
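The new `http.delete` helper follows the same pattern as `fetch` and `post`: it takes a fully qualified endpoint URL plus headers, raises `urllib3.exceptions.HTTPError` on anything other than 200/204, and returns the decoded JSON body (an empty dict for empty responses). An illustrative call with placeholder endpoint and API key:

```python
from hafnia import http

API_KEY = "hafnia_xxx"  # placeholder API key
ENDPOINT = "https://api.example.com/api/v1/dataset-recipes/1234"  # placeholder URL

headers = {"Authorization": API_KEY}
response = http.delete(ENDPOINT, headers=headers)
print(response)  # {} when the server replies with an empty 204
```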
hafnia/platform/__init__.py CHANGED

@@ -1,3 +1,4 @@
+from hafnia.platform.datasets import get_dataset_id
 from hafnia.platform.download import (
     download_resource,
     download_single_object,
@@ -5,17 +6,22 @@ from hafnia.platform.download import (
 )
 from hafnia.platform.experiment import (
     create_experiment,
-
-    get_dataset_id,
+    get_environments,
     get_exp_environment_id,
+    pretty_print_training_environments,
 )
+from hafnia.platform.trainer_package import create_trainer_package, get_trainer_package_by_id, get_trainer_packages
 
 __all__ = [
     "get_dataset_id",
-    "
+    "create_trainer_package",
+    "get_trainer_packages",
+    "get_trainer_package_by_id",
     "get_exp_environment_id",
     "create_experiment",
     "download_resource",
     "download_single_object",
     "get_resource_credentials",
+    "pretty_print_training_environments",
+    "get_environments",
 ]
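With the reworked `hafnia/platform/__init__.py`, the trainer-package helpers, environment listing, and `get_dataset_id` (now living in `platform/datasets.py`) are all re-exported at package level, so they can be imported directly:

```python
from hafnia.platform import (
    create_experiment,
    create_trainer_package,
    get_dataset_id,
    get_environments,
    get_trainer_packages,
    pretty_print_training_environments,
)
```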
hafnia/platform/builder.py CHANGED

@@ -14,26 +14,28 @@ from hafnia.log import sys_logger, user_logger
 from hafnia.platform import download_resource
 
 
-def
-    """Validate Hafnia
+def validate_trainer_package_format(path: Path) -> None:
+    """Validate Hafnia Trainer Package Format submission"""
     hrf = zipfile.Path(path) if path.suffix == ".zip" else path
     required = {"src", "scripts", "Dockerfile"}
     errors = 0
     for rp in required:
         if not (hrf / rp).exists():
-            user_logger.error(f"Required path {rp} not found in
+            user_logger.error(f"Required path {rp} not found in trainer package.")
             errors += 1
     if errors > 0:
-        raise FileNotFoundError("Wrong
+        raise FileNotFoundError("Wrong trainer package structure")
 
 
-def
-
-
-
+def prepare_trainer_package(
+    trainer_url: str, output_dir: Path, api_key: str, state_file: Optional[Path] = None
+) -> Dict:
+    resource = download_resource(trainer_url, output_dir.as_posix(), api_key)
+    trainer_path = Path(resource["downloaded_files"][0])
+    with zipfile.ZipFile(trainer_path, "r") as zip_ref:
         zip_ref.extractall(output_dir)
 
-
+    validate_trainer_package_format(output_dir)
 
     scripts_dir = output_dir / "scripts"
     if not any(scripts_dir.iterdir()):
@@ -42,7 +44,7 @@ def prepare_recipe(recipe_url: str, output_dir: Path, api_key: str, state_file:
     metadata = {
         "user_data": (output_dir / "src").as_posix(),
         "dockerfile": (output_dir / "Dockerfile").as_posix(),
-        "digest": sha256(
+        "digest": sha256(trainer_path.read_bytes()).hexdigest()[:8],
     }
     state_file = state_file if state_file else output_dir / "state.json"
     with open(state_file, "w", encoding="utf-8") as f:
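`prepare_recipe` has effectively become `prepare_trainer_package`: it downloads the trainer archive, unzips it into `output_dir`, validates the required `src`/`scripts`/`Dockerfile` layout, and writes a `state.json` whose digest is derived from the archive bytes. A hedged sketch of calling it; the URL, API key and output directory are placeholders, and the exact contents of the returned `Dict` are not shown in this diff:

```python
from pathlib import Path

from hafnia.platform.builder import prepare_trainer_package, validate_trainer_package_format

output_dir = Path("./trainer_build")  # placeholder output directory
output_dir.mkdir(parents=True, exist_ok=True)

state = prepare_trainer_package(
    trainer_url="https://api.example.com/trainer-packages/1234/download",  # placeholder URL
    api_key="hafnia_xxx",  # placeholder API key
    output_dir=output_dir,
)

# The format check can also be run on its own against a folder or a .zip file.
validate_trainer_package_format(output_dir)
```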
hafnia/platform/dataset_recipe.py ADDED

@@ -0,0 +1,99 @@
+import json
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from flatten_dict import flatten
+
+from hafnia import http
+from hafnia.log import user_logger
+from hafnia.utils import pretty_print_list_as_table, timed
+
+
+@timed("Get or create dataset recipe")
+def get_or_create_dataset_recipe(
+    recipe: dict, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    data = {"template": {"body": recipe}}
+    if name is not None:
+        data["name"] = name  # type: ignore[assignment]
+    response = http.post(endpoint, headers=headers, data=data)
+    return response
+
+
+def get_or_create_dataset_recipe_by_dataset_name(dataset_name: str, endpoint: str, api_key: str) -> Dict:
+    return get_or_create_dataset_recipe(recipe=dataset_name, endpoint=endpoint, api_key=api_key)
+
+
+def get_dataset_recipes(endpoint: str, api_key: str) -> List[Dict]:
+    headers = {"Authorization": api_key}
+    dataset_recipes: List[Dict] = http.fetch(endpoint, headers=headers)  # type: ignore[assignment]
+    return dataset_recipes
+
+
+def get_dataset_recipe_by_id(dataset_recipe_id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{dataset_recipe_id}"
+    dataset_recipe_info: Dict = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_recipe_info:
+        raise ValueError(f"Dataset recipe with ID '{dataset_recipe_id}' was not found.")
+    return dataset_recipe_info
+
+
+def get_or_create_dataset_recipe_from_path(
+    path_recipe_json: Path, endpoint: str, api_key: str, name: Optional[str] = None
+) -> Dict:
+    path_recipe_json = Path(path_recipe_json)
+    if not path_recipe_json.exists():
+        raise FileNotFoundError(f"Dataset recipe file '{path_recipe_json}' does not exist.")
+    json_dict = json.loads(path_recipe_json.read_text())
+    return get_or_create_dataset_recipe(json_dict, endpoint=endpoint, api_key=api_key, name=name)
+
+
+def delete_dataset_recipe_by_id(id: str, endpoint: str, api_key: str) -> Dict:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}/{id}"
+    response = http.delete(endpoint=full_url, headers=headers)
+    return response
+
+
+@timed("Get dataset recipe")
+def get_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={name}"
+    dataset_recipes: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if len(dataset_recipes) == 0:
+        return None
+
+    if len(dataset_recipes) > 1:
+        user_logger.warning(f"Found {len(dataset_recipes)} dataset recipes called '{name}'. Using the first one.")
+
+    dataset_recipe = dataset_recipes[0]
+    return dataset_recipe
+
+
+def delete_dataset_recipe_by_name(name: str, endpoint: str, api_key: str) -> Optional[Dict]:
+    recipe_response = get_dataset_recipe_by_name(name, endpoint=endpoint, api_key=api_key)
+
+    if recipe_response:
+        return delete_dataset_recipe_by_id(recipe_response["id"], endpoint=endpoint, api_key=api_key)
+    return recipe_response
+
+
+def pretty_print_dataset_recipes(recipes: List[Dict]) -> None:
+    recipes = [flatten(recipe, reducer="dot", max_flatten_depth=2) for recipe in recipes]  # noqa: F821
+    for recipe in recipes:
+        recipe["recipe_json"] = json.dumps(recipe["template.body"])[:20]
+
+    RECIPE_FIELDS = {
+        "ID": "id",
+        "Name": "name",
+        "Recipe": "recipe_json",
+        "Created": "created_at",
+        "IsDataset": "template.is_direct_dataset_reference",
+    }
+    pretty_print_list_as_table(
+        table_title="Available Dataset Recipes",
+        dict_items=recipes,
+        column_name_to_key_mapping=RECIPE_FIELDS,
+    )
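The new `hafnia/platform/dataset_recipe.py` module wraps the dataset-recipe endpoints (create, list, fetch, delete, pretty-print). A hedged example that registers a recipe from a local JSON file and looks it up again by name; the endpoint and API key are placeholders:

```python
from pathlib import Path

from hafnia.platform.dataset_recipe import (
    get_dataset_recipe_by_name,
    get_or_create_dataset_recipe_from_path,
)

ENDPOINT = "https://api.example.com/api/v1/dataset-recipes"  # placeholder endpoint
API_KEY = "hafnia_xxx"  # placeholder API key

recipe = get_or_create_dataset_recipe_from_path(
    path_recipe_json=Path("recipe.json"),
    endpoint=ENDPOINT,
    api_key=API_KEY,
    name="my-recipe",
)

# Name lookup uses ?name__iexact=<name> and returns None when nothing matches.
found = get_dataset_recipe_by_name("my-recipe", endpoint=ENDPOINT, api_key=API_KEY)
```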
hafnia/platform/datasets.py CHANGED

@@ -8,10 +8,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import rich
+from rich import print as rprint
 from tqdm import tqdm
 
 from cli.config import Config
-from hafnia import utils
+from hafnia import http, utils
 from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
 from hafnia.dataset.dataset_recipe.dataset_recipe import (
     DatasetRecipe,
@@ -20,13 +21,12 @@ from hafnia.dataset.dataset_recipe.dataset_recipe import (
 from hafnia.dataset.hafnia_dataset import HafniaDataset
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
-from hafnia.platform import get_dataset_id
 from hafnia.platform.download import get_resource_credentials
 from hafnia.utils import timed
 
 
 @timed("Fetching dataset list.")
-def
+def get_datasets(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     """List available datasets on the Hafnia platform."""
     cfg = cfg or Config()
     endpoint_dataset = cfg.get_platform_endpoint("datasets")
@@ -38,6 +38,19 @@ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     return datasets
 
 
+@timed("Fetching dataset info.")
+def get_dataset_id(dataset_name: str, endpoint: str, api_key: str) -> str:
+    headers = {"Authorization": api_key}
+    full_url = f"{endpoint}?name__iexact={dataset_name}"
+    dataset_responses: List[Dict] = http.fetch(full_url, headers=headers)  # type: ignore[assignment]
+    if not dataset_responses:
+        raise ValueError(f"Dataset '{dataset_name}' was not found in the dataset library.")
+    try:
+        return dataset_responses[0]["id"]
+    except (IndexError, KeyError) as e:
+        raise ValueError("Dataset information is missing or invalid") from e
+
+
 def download_or_get_dataset_path(
     dataset_name: str,
     cfg: Optional[Config] = None,
@@ -131,6 +144,28 @@ def fast_copy_files_s3(
     return lines
 
 
+def find_s5cmd() -> Optional[str]:
+    """Locate the s5cmd executable across different installation methods.
+
+    Searches for s5cmd in:
+    1. System PATH (via shutil.which)
+    2. Python bin directory (Unix-like systems)
+    3. Python executable directory (direct installs)
+
+    Returns:
+        str: Absolute path to s5cmd executable if found, None otherwise.
+    """
+    result = shutil.which("s5cmd")
+    if result:
+        return result
+    python_dir = Path(sys.executable).parent
+    locations = (python_dir / "Scripts" / "s5cmd.exe", python_dir / "bin" / "s5cmd", python_dir / "s5cmd")
+    for loc in locations:
+        if loc.exists():
+            return str(loc)
+    return None
+
+
 def execute_s5cmd_commands(
     commands: List[str],
     append_envs: Optional[Dict[str, str]] = None,
@@ -142,7 +177,10 @@ def execute_s5cmd_commands(
     with tempfile.TemporaryDirectory() as temp_dir:
         tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
         tmp_file_path.write_text("\n".join(commands))
-
+
+        s5cmd_bin = find_s5cmd()
+        if s5cmd_bin is None:
+            raise ValueError("Can not find s5cmd executable.")
         run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
         sys_logger.debug(run_cmds)
         envs = os.environ.copy()
@@ -185,7 +223,7 @@ TABLE_FIELDS = {
 }
 
 
-def
+def pretty_print_datasets(datasets: List[Dict[str, str]]) -> None:
     datasets = extend_dataset_details(datasets)
     datasets = sorted(datasets, key=lambda x: x["name"].lower())
 
@@ -197,7 +235,7 @@ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table
         row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
         table.add_row(*row)
 
-
+    rprint(table)
 
 
 def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
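Two smaller additions to `datasets.py`: `get_dataset_id` (moved here from `experiment.py` and re-exported via `hafnia.platform`) resolves a dataset name to its platform ID, and `find_s5cmd` locates the s5cmd binary before bulk S3 copies run. A hedged sketch with placeholder endpoint and API key:

```python
from hafnia.platform.datasets import find_s5cmd, get_dataset_id

ENDPOINT = "https://api.example.com/api/v1/datasets"  # placeholder endpoint
API_KEY = "hafnia_xxx"  # placeholder API key

dataset_id = get_dataset_id("mnist", endpoint=ENDPOINT, api_key=API_KEY)  # dataset name is illustrative

# Checks PATH first, then the Python bin/Scripts directories; returns None if not installed.
s5cmd_bin = find_s5cmd()
if s5cmd_bin is None:
    raise ValueError("Can not find s5cmd executable.")
```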
hafnia/platform/download.py CHANGED

@@ -96,7 +96,8 @@ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials
         RuntimeError: If the call to fetch the credentials fails for any reason.
     """
     try:
-
+        headers = {"Authorization": api_key, "accept": "application/json"}
+        credentials_dict: Dict = fetch(endpoint, headers=headers)  # type: ignore[assignment]
         credentials = ResourceCredentials.fix_naming(credentials_dict)
         sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
         return credentials