hafnia 0.1.27__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cli/__main__.py +2 -2
  2. cli/dataset_cmds.py +60 -0
  3. cli/runc_cmds.py +1 -1
  4. hafnia/data/__init__.py +2 -2
  5. hafnia/data/factory.py +9 -56
  6. hafnia/dataset/dataset_helpers.py +91 -0
  7. hafnia/dataset/dataset_names.py +71 -0
  8. hafnia/dataset/dataset_transformation.py +187 -0
  9. hafnia/dataset/dataset_upload_helper.py +468 -0
  10. hafnia/dataset/hafnia_dataset.py +453 -0
  11. hafnia/dataset/primitives/__init__.py +16 -0
  12. hafnia/dataset/primitives/bbox.py +137 -0
  13. hafnia/dataset/primitives/bitmask.py +182 -0
  14. hafnia/dataset/primitives/classification.py +56 -0
  15. hafnia/dataset/primitives/point.py +25 -0
  16. hafnia/dataset/primitives/polygon.py +100 -0
  17. hafnia/dataset/primitives/primitive.py +44 -0
  18. hafnia/dataset/primitives/segmentation.py +51 -0
  19. hafnia/dataset/primitives/utils.py +51 -0
  20. hafnia/dataset/table_transformations.py +183 -0
  21. hafnia/experiment/hafnia_logger.py +2 -2
  22. hafnia/helper_testing.py +63 -0
  23. hafnia/http.py +5 -3
  24. hafnia/platform/__init__.py +2 -2
  25. hafnia/platform/datasets.py +184 -0
  26. hafnia/platform/download.py +85 -23
  27. hafnia/torch_helpers.py +180 -95
  28. hafnia/utils.py +1 -1
  29. hafnia/visualizations/colors.py +267 -0
  30. hafnia/visualizations/image_visualizations.py +202 -0
  31. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
  32. hafnia-0.2.0.dist-info/RECORD +46 -0
  33. cli/data_cmds.py +0 -53
  34. hafnia-0.1.27.dist-info/RECORD +0 -27
  35. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
  36. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/table_transformations.py ADDED
@@ -0,0 +1,183 @@
+ from pathlib import Path
+ from typing import List, Optional, Type
+
+ import polars as pl
+ from tqdm import tqdm
+
+ from hafnia.dataset import table_transformations
+ from hafnia.dataset.dataset_names import (
+     FILENAME_ANNOTATIONS_JSONL,
+     FILENAME_ANNOTATIONS_PARQUET,
+     FieldName,
+ )
+ from hafnia.dataset.primitives import PRIMITIVE_TYPES
+ from hafnia.dataset.primitives.classification import Classification
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.log import user_logger
+
+
+ def create_primitive_table(
+     samples_table: pl.DataFrame, PrimitiveType: Type[Primitive], keep_sample_data: bool = False
+ ) -> Optional[pl.DataFrame]:
+     """
+     Returns a DataFrame with objects of the specified primitive type.
+     """
+     column_name = PrimitiveType.column_name()
+     has_primitive_column = (column_name in samples_table.columns) and (
+         samples_table[column_name].dtype == pl.List(pl.Struct)
+     )
+     if not has_primitive_column:
+         return None
+
+     # Remove frames without objects
+     remove_no_object_frames = samples_table.filter(pl.col(column_name).list.len() > 0)
+
+     if keep_sample_data:
+         # Drop other primitive columns to avoid conflicts
+         drop_columns = set(PRIMITIVE_TYPES) - {PrimitiveType, Classification}
+         remove_no_object_frames = remove_no_object_frames.drop(*[primitive.column_name() for primitive in drop_columns])
+         # Rename the sample-level "height", "width" and "meta" columns to avoid conflicts with object field names
+         remove_no_object_frames = remove_no_object_frames.rename(
+             {"height": "image.height", "width": "image.width", "meta": "image.meta"}
+         )
+         objects_df = remove_no_object_frames.explode(column_name).unnest(column_name)
+     else:
+         objects_df = remove_no_object_frames.select(pl.col(column_name).explode().struct.unnest())
+     return objects_df
+
+
+ def filter_table_for_class_names(
+     samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
+ ) -> Optional[pl.DataFrame]:
+     table_with_selected_class_names = samples_table.filter(
+         pl.col(PrimitiveType.column_name())
+         .list.eval(pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names))
+         .list.any()
+     )
+
+     return table_with_selected_class_names
+
+
+ def split_primitive_columns_by_task_name(
+     samples_table: pl.DataFrame,
+     coordinate_types: Optional[List[Type[Primitive]]] = None,
+ ) -> pl.DataFrame:
+     """
+     Convert primitive columns such as "objects" (Bbox) into one column per task name.
+     For example, if the "objects" column (containing Bbox objects) has tasks "task1" and "task2",
+     this:
+     ─┬────────────┬─
+      ┆ objects    ┆
+      ┆ ---        ┆
+      ┆ list[struc ┆
+      ┆ t[11]]     ┆
+     ═╪════════════╪═
+     becomes this:
+     ─┬────────────┬────────────┬─
+      ┆ objects.   ┆ objects.   ┆
+      ┆ task1      ┆ task2      ┆
+      ┆ ---        ┆ ---        ┆
+      ┆ list[struc ┆ list[struc ┆
+      ┆ t[11]]     ┆ t[13]]     ┆
+     ═╪════════════╪════════════╪═
+     """
+     coordinate_types = coordinate_types or PRIMITIVE_TYPES
+     for PrimitiveType in coordinate_types:
+         col_name = PrimitiveType.column_name()
+
+         if col_name not in samples_table.columns:
+             continue
+
+         if samples_table[col_name].dtype != pl.List(pl.Struct):
+             continue
+
+         task_names = samples_table[col_name].explode().struct.field(FieldName.TASK_NAME).unique().to_list()
+         samples_table = samples_table.with_columns(
+             [
+                 pl.col(col_name)
+                 .list.filter(pl.element().struct.field(FieldName.TASK_NAME).eq(task_name))
+                 .alias(f"{col_name}.{task_name}")
+                 for task_name in task_names
+             ]
+         )
+         samples_table = samples_table.drop(col_name)
+     return samples_table
+
+
+ def read_table_from_path(path: Path) -> pl.DataFrame:
+     path_annotations = path / FILENAME_ANNOTATIONS_PARQUET
+     if path_annotations.exists():
+         user_logger.info(f"Reading dataset annotations from Parquet file: {path_annotations}")
+         return pl.read_parquet(path_annotations)
+
+     path_annotations_jsonl = path / FILENAME_ANNOTATIONS_JSONL
+     if path_annotations_jsonl.exists():
+         user_logger.info(f"Reading dataset annotations from JSONL file: {path_annotations_jsonl}")
+         return pl.read_ndjson(path_annotations_jsonl)
+
+     raise FileNotFoundError(
+         f"Unable to read annotations. No JSONL file '{path_annotations_jsonl.name}' or Parquet file '{path_annotations.name}' in '{path}'."
+     )
+
+
+ def check_image_paths(table: pl.DataFrame) -> bool:
+     missing_files = []
+     for org_path in tqdm(table["file_name"].to_list(), desc="Check image paths"):
+         org_path = Path(org_path)
+         if not org_path.exists():
+             missing_files.append(org_path)
+
+     if len(missing_files) > 0:
+         user_logger.warning(f"Missing files: {len(missing_files)}. Show first 5:")
+         for missing_file in missing_files[:5]:
+             user_logger.warning(f"  - {missing_file}")
+         raise FileNotFoundError(f"Some files are missing in the dataset: {len(missing_files)} files not found.")
+
+     return True
+
+
+ def unnest_classification_tasks(table: pl.DataFrame, strict: bool = True) -> pl.DataFrame:
+     """
+     Unnest classification tasks in the table.
+     Classification tasks are all stored in the same column of the HafniaDataset table.
+     This function splits them into a separate column per task name.
+
+     The type is converted from a list of structs (pl.List[pl.Struct]) to a struct (pl.Struct) column.
+
+     For example, if the classification column has tasks "task1" and "task2", it converts this:
+     ─┬─────────────────┬─
+      ┆ classifications ┆
+      ┆ ---             ┆
+      ┆ list[struct[6]] ┆
+     ═╪═════════════════╪═
+
+     into this:
+     ─┬──────────────────┬──────────────────┬─
+      ┆ classifications. ┆ classifications. ┆
+      ┆ task1            ┆ task2            ┆
+      ┆ ---              ┆ ---              ┆
+      ┆ struct[6]        ┆ struct[6]        ┆
+     ═╪══════════════════╪══════════════════╪═
+     """
+     coordinate_types = [Classification]
+     table_out = table_transformations.split_primitive_columns_by_task_name(table, coordinate_types=coordinate_types)
+
+     classification_columns = [c for c in table_out.columns if c.startswith(Classification.column_name() + ".")]
+     for classification_column in classification_columns:
+         has_multiple_items_per_sample = all(table_out[classification_column].list.len() > 1)
+         if has_multiple_items_per_sample:
+             if strict:
+                 raise ValueError(
+                     f"Column {classification_column} has multiple items per sample, but expected only one item."
+                 )
+             else:
+                 user_logger.warning(
+                     f"Unnesting of column '{classification_column}' is skipped because it has multiple items per sample."
+                 )
+
+     table_out = table_out.with_columns([pl.col(c).list.first() for c in classification_columns])
+     return table_out
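
To make the list[struct] handling above concrete, here is a minimal sketch (not part of the package) of the class-name filter and the explode/unnest pattern that filter_table_for_class_names and create_primitive_table are built on. The column and field names ("objects", "class_name", "task_name") are illustrative stand-ins for the schema constants used in the module:

import polars as pl

# Toy samples table: one row per image, annotations stored as a list[struct] column.
samples = pl.DataFrame(
    {
        "file_name": ["a.jpg", "b.jpg"],
        "objects": [
            [{"class_name": "car", "task_name": "detection"}],
            [{"class_name": "person", "task_name": "detection"}],
        ],
    }
)

# Keep only frames that contain at least one "car" (same expression shape as filter_table_for_class_names).
cars_only = samples.filter(
    pl.col("objects")
    .list.eval(pl.element().struct.field("class_name").is_in(["car"]))
    .list.any()
)

# Flatten to one row per annotation (the non-keep_sample_data path of create_primitive_table).
objects = cars_only.select(pl.col("objects").explode().struct.unnest())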
hafnia/experiment/hafnia_logger.py CHANGED
@@ -9,10 +9,10 @@ from typing import Dict, Optional, Union
 
  import pyarrow as pa
  import pyarrow.parquet as pq
- from datasets import DatasetDict
  from pydantic import BaseModel, field_validator
 
  from hafnia.data.factory import load_dataset
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
  from hafnia.log import sys_logger, user_logger
  from hafnia.utils import is_remote_job, now_as_str
 
@@ -92,7 +92,7 @@ class HafniaLogger:
          self.schema = Entity.create_schema()
          self.log_environment()
 
-     def load_dataset(self, dataset_name: str) -> DatasetDict:
+     def load_dataset(self, dataset_name: str) -> HafniaDataset:
          """
          Load a dataset from the specified path.
          """
hafnia/helper_testing.py ADDED
@@ -0,0 +1,63 @@
+ from pathlib import Path
+
+ from hafnia import utils
+ from hafnia.dataset.dataset_names import FILENAME_ANNOTATIONS_JSONL, DatasetVariant
+ from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+ MICRO_DATASETS = {
+     "tiny-dataset": utils.PATH_DATASETS / "tiny-dataset",
+     "coco-2017": utils.PATH_DATASETS / "coco-2017",
+ }
+
+
+ def get_path_workspace() -> Path:
+     return Path(__file__).parents[2]
+
+
+ def get_path_expected_images() -> Path:
+     return get_path_workspace() / "tests" / "data" / "expected_images"
+
+
+ def get_path_test_data() -> Path:
+     return get_path_workspace() / "tests" / "data"
+
+
+ def get_path_micro_hafnia_dataset_no_check() -> Path:
+     return get_path_test_data() / "micro_test_datasets"
+
+
+ def get_path_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Path:
+     import pytest
+
+     if dataset_name not in MICRO_DATASETS:
+         raise ValueError(f"Dataset name '{dataset_name}' is not recognized. Available options: {list(MICRO_DATASETS)}")
+     path_dataset = MICRO_DATASETS[dataset_name]
+
+     path_test_dataset = get_path_micro_hafnia_dataset_no_check() / dataset_name
+     path_test_dataset_annotations = path_test_dataset / FILENAME_ANNOTATIONS_JSONL
+     if path_test_dataset_annotations.exists() and not force_update:
+         return path_test_dataset
+
+     hafnia_dataset = HafniaDataset.read_from_path(path_dataset / DatasetVariant.SAMPLE.value)
+     hafnia_dataset = hafnia_dataset.sample(n_samples=3, seed=42)
+     hafnia_dataset.write(path_test_dataset)
+
+     if force_update:
+         pytest.fail(
+             "Sample image and metadata have been updated using 'force_update=True'. Set 'force_update=False' and rerun the test."
+         )
+     pytest.fail("Missing test sample image. Please rerun the test.")
+     return path_test_dataset
+
+
+ def get_sample_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Sample:
+     micro_dataset = get_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     sample_dict = micro_dataset[0]
+     sample = Sample(**sample_dict)
+     return sample
+
+
+ def get_micro_hafnia_dataset(dataset_name: str, force_update: bool = False) -> HafniaDataset:
+     path_dataset = get_path_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     hafnia_dataset = HafniaDataset.read_from_path(path_dataset)
+     return hafnia_dataset
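
A short sketch of how these helpers could back a test; "tiny-dataset" is one of the MICRO_DATASETS keys defined above, and the test body itself is only illustrative:

from hafnia.helper_testing import get_micro_hafnia_dataset, get_sample_micro_hafnia_dataset

def test_micro_dataset_sample():
    # Loads (or rebuilds on first use) the cached 3-sample micro dataset and its first Sample.
    dataset = get_micro_hafnia_dataset("tiny-dataset")
    sample = get_sample_micro_hafnia_dataset("tiny-dataset")
    assert dataset is not None and sample is not None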
hafnia/http.py CHANGED
@@ -31,7 +31,7 @@ def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
          http.clear()
 
 
- def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart: bool = False) -> Dict:
+ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], multipart: bool = False) -> Dict:
      """Posts data to backend endpoint.
 
      Args:
@@ -64,9 +64,11 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart
          with open(data, "rb") as f:
              body = f.read()
          response = http.request("POST", endpoint, body=body, headers=headers)
-     elif isinstance(data, dict):
+     elif isinstance(data, (str, dict)):
+         if isinstance(data, dict):
+             data = json.dumps(data)
          headers["Content-Type"] = "application/json"
-         response = http.request("POST", endpoint, body=json.dumps(data), headers=headers)
+         response = http.request("POST", endpoint, body=data, headers=headers)
      elif isinstance(data, bytes):
          response = http.request("POST", endpoint, body=data, headers=headers)
      else:
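
The widened signature means post() now accepts a pre-serialized JSON string in addition to a dict; both calls below send the same body (the endpoint and headers are placeholders):

from hafnia.http import post

headers = {"Authorization": "<api-key>"}
# Dict payloads are serialized by post() itself ...
post("https://api.example.com/items", headers=headers, data={"name": "demo"})
# ... and, as of 0.2.0, an already-serialized JSON string is sent as-is.
post("https://api.example.com/items", headers=headers, data='{"name": "demo"}')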
hafnia/platform/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from hafnia.platform.download import (
      download_resource,
      download_single_object,
-     get_resource_creds,
+     get_resource_credentials,
  )
  from hafnia.platform.experiment import (
      create_experiment,
@@ -17,5 +17,5 @@ __all__ = [
      "create_experiment",
      "download_resource",
      "download_single_object",
-     "get_resource_creds",
+     "get_resource_credentials",
  ]
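
For downstream code this is a plain rename; the old symbol is no longer exported:

# hafnia 0.1.27
from hafnia.platform import get_resource_creds
# hafnia 0.2.0
from hafnia.platform import get_resource_credentials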
hafnia/platform/datasets.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import shutil
+ import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import rich
+ from tqdm import tqdm
+
+ from cli.config import Config
+ from hafnia import utils
+ from hafnia.dataset import dataset_names
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
+ from hafnia.http import fetch
+ from hafnia.log import user_logger
+ from hafnia.platform import get_dataset_id
+ from hafnia.platform.download import get_resource_credentials
+ from hafnia.utils import timed
+
+
+ @timed("Fetching dataset list.")
+ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+     """List available datasets on the Hafnia platform."""
+     cfg = cfg or Config()
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     header = {"Authorization": cfg.api_key}
+     datasets: List[Dict[str, str]] = fetch(endpoint_dataset, headers=header)  # type: ignore
+     if not datasets:
+         raise ValueError("No datasets found on the Hafnia platform.")
+
+     return datasets
+
+
+ def download_or_get_dataset_path(
+     dataset_name: str,
+     cfg: Optional[Config] = None,
+     path_datasets_folder: Optional[str] = None,
+     force_redownload: bool = False,
+ ) -> Path:
+     """Download the dataset or return the path of an existing local copy."""
+     if utils.is_remote_job():
+         return Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
+
+     path_datasets_folder = path_datasets_folder or str(utils.PATH_DATASETS)
+     path_dataset = Path(path_datasets_folder).absolute() / dataset_name
+
+     is_dataset_valid = HafniaDataset.check_dataset_path(path_dataset, raise_error=False)
+     if is_dataset_valid and not force_redownload:
+         user_logger.info("Dataset found locally. Set 'force=True' or add the `--force` CLI flag to re-download.")
+         return path_dataset
+
+     cfg = cfg or Config()
+     api_key = cfg.api_key
+
+     shutil.rmtree(path_dataset, ignore_errors=True)
+
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     dataset_id = get_dataset_id(dataset_name=dataset_name, endpoint=endpoint_dataset, api_key=api_key)
+     access_dataset_endpoint = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
+
+     download_dataset_from_access_endpoint(
+         endpoint=access_dataset_endpoint,
+         api_key=api_key,
+         path_dataset=path_dataset,
+     )
+     return path_dataset
+
+
+ def download_dataset_from_access_endpoint(endpoint: str, api_key: str, path_dataset: Path) -> None:
+     resource_credentials = get_resource_credentials(endpoint, api_key)
+
+     local_dataset_paths = [str(path_dataset / filename) for filename in dataset_names.DATASET_FILENAMES]
+     s3_uri = resource_credentials.s3_uri()
+     s3_dataset_files = [f"{s3_uri}/{filename}" for filename in dataset_names.DATASET_FILENAMES]
+
+     envs = resource_credentials.aws_credentials()
+     fast_copy_files_s3(
+         src_paths=s3_dataset_files,
+         dst_paths=local_dataset_paths,
+         append_envs=envs,
+         description="Downloading annotations",
+     )
+
+     dataset = HafniaDataset.read_from_path(path_dataset, check_for_images=False)
+     fast_copy_files_s3(
+         src_paths=dataset.samples[dataset_names.ColumnName.REMOTE_PATH].to_list(),
+         dst_paths=dataset.samples[dataset_names.ColumnName.FILE_NAME].to_list(),
+         append_envs=envs,
+         description="Downloading images",
+     )
+
+
+ def fast_copy_files_s3(
+     src_paths: List[str],
+     dst_paths: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Copying files",
+ ) -> List[str]:
+     if len(src_paths) != len(dst_paths):
+         raise ValueError("Source and destination paths must have the same length.")
+
+     cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
+     lines = execute_s5cmd_commands(cmds, append_envs=append_envs, description=description)
+     return lines
+
+
+ def execute_s5cmd_commands(
+     commands: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Executing s5cmd commands",
+ ) -> List[str]:
+     append_envs = append_envs or {}
+     with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_file:
+         tmp_file_path = Path(tmp_file.name)
+         tmp_file_path.write_text("\n".join(commands))
+         run_cmds = [
+             "s5cmd",
+             "run",
+             str(tmp_file_path),
+         ]
+         envs = os.environ.copy()
+         envs.update(append_envs)
+
+         process = subprocess.Popen(
+             run_cmds,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             universal_newlines=True,
+             env=envs,
+         )
+
+         error_lines = []
+         lines = []
+         for line in tqdm(process.stdout, total=len(commands), desc=description):
+             if "ERROR" in line or "error" in line:
+                 error_lines.append(line.strip())
+             lines.append(line.strip())
+
+         if len(error_lines) > 0:
+             show_n_lines = min(5, len(error_lines))
+             str_error_lines = "\n".join(error_lines[:show_n_lines])
+             user_logger.error(
+                 f"Detected {len(error_lines)} errors while executing a total of {len(commands)} "
+                 f"commands with s5cmd. The first {show_n_lines} are printed below:\n{str_error_lines}"
+             )
+             raise RuntimeError("Errors occurred during s5cmd execution.")
+     return lines
+
+
+ TABLE_FIELDS = {
+     "ID": "id",
+     "Hidden\nSamples": "hidden.samples",
+     "Hidden\nSize": "hidden.size",
+     "Sample\nSamples": "sample.samples",
+     "Sample\nSize": "sample.size",
+     "Name": "name",
+     "Title": "title",
+ }
+
+
+ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+     datasets = extend_dataset_details(datasets)
+     datasets = sorted(datasets, key=lambda x: x["name"].lower())
+
+     table = rich.table.Table(title="Available Datasets")
+     for i_dataset, dataset in enumerate(datasets):
+         if i_dataset == 0:
+             for column_name, _ in TABLE_FIELDS.items():
+                 table.add_column(column_name, justify="left", style="cyan", no_wrap=True)
+         row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
+         table.add_row(*row)
+
+     return table
+
+
+ def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Extend dataset details with the number of samples and the size of each variant."""
+     for dataset in datasets:
+         for variant in dataset["dataset_variants"]:
+             variant_type = variant["variant_type"]
+             dataset[f"{variant_type}.samples"] = variant["number_of_data_items"]
+             dataset[f"{variant_type}.size"] = utils.size_human_readable(variant["size_bytes"])
+     return datasets
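
A minimal sketch of how these pieces compose, assuming an API key is already configured for Config: dataset_list() fetches the catalog and create_rich_table_from_dataset renders it with rich:

import rich

from hafnia.platform.datasets import create_rich_table_from_dataset, dataset_list

datasets = dataset_list()  # raises ValueError if the platform returns no datasets
rich.print(create_rich_table_from_dataset(datasets))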
hafnia/platform/download.py CHANGED
@@ -1,15 +1,87 @@
  from pathlib import Path
- from typing import Any, Dict
+ from typing import Dict
 
  import boto3
  from botocore.exceptions import ClientError
+ from pydantic import BaseModel, field_validator
  from tqdm import tqdm
 
  from hafnia.http import fetch
  from hafnia.log import sys_logger, user_logger
 
-
- def get_resource_creds(endpoint: str, api_key: str) -> Dict[str, Any]:
+ ARN_PREFIX = "arn:aws:s3:::"
+
+
+ class ResourceCredentials(BaseModel):
+     access_key: str
+     secret_key: str
+     session_token: str
+     s3_arn: str
+     region: str
+
+     @staticmethod
+     def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+         """
+         The endpoint returns a payload with a key called 's3_path', but the value is
+         actually an ARN (it starts with 'arn:aws:s3:::'). This method renames it to 's3_arn' for consistency.
+         """
+         if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+             payload["s3_arn"] = payload.pop("s3_path")
+
+         if "region" not in payload:
+             payload["region"] = "eu-west-1"
+         return ResourceCredentials(**payload)
+
+     @field_validator("s3_arn")
+     @classmethod
+     def validate_s3_arn(cls, value: str) -> str:
+         """Validate that 's3_arn' starts with 'arn:aws:s3:::'."""
+         if not value.startswith("arn:aws:s3:::"):
+             raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+         return value
+
+     def s3_path(self) -> str:
+         """
+         Extracts the S3 path from the ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+         """
+         return self.s3_arn[len(ARN_PREFIX) :]
+
+     def s3_uri(self) -> str:
+         """
+         Converts the S3 ARN to a URI format.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+         """
+         return f"s3://{self.s3_path()}"
+
+     def bucket_name(self) -> str:
+         """
+         Extracts the bucket name from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+         """
+         return self.s3_path().split("/")[0]
+
+     def object_key(self) -> str:
+         """
+         Extracts the object key from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+         """
+         return "/".join(self.s3_path().split("/")[1:])
+
+     def aws_credentials(self) -> Dict[str, str]:
+         """
+         Returns the AWS credentials as a dictionary of environment variables.
+         """
+         environment_vars = {
+             "AWS_ACCESS_KEY_ID": self.access_key,
+             "AWS_SECRET_ACCESS_KEY": self.secret_key,
+             "AWS_SESSION_TOKEN": self.session_token,
+             "AWS_REGION": self.region,
+         }
+         return environment_vars
+
+
+ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
      """
      Retrieve credentials for accessing the recipe stored in S3 (or another resource)
      by calling a DIP endpoint with the API key.
@@ -18,21 +90,16 @@ def get_resource_creds(endpoint: str, api_key: str) -> Dict[str, Any]:
          endpoint (str): The endpoint URL to fetch credentials from.
 
      Returns:
-         Dict[str, Any]: Dictionary containing the credentials, for example:
-             {
-                 "access_key": str,
-                 "secret_key": str,
-                 "session_token": str,
-                 "s3_path": str
-             }
+         ResourceCredentials
 
      Raises:
          RuntimeError: If the call to fetch the credentials fails for any reason.
      """
      try:
-         creds = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+         credentials_dict = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+         credentials = ResourceCredentials.fix_naming(credentials_dict)
          sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
-         return creds
+         return credentials
      except Exception as e:
          sys_logger.error(f"Failed to fetch credentials from endpoint: {e}")
          raise RuntimeError(f"Failed to retrieve credentials: {e}") from e
@@ -76,23 +143,18 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
          ValueError: If the S3 ARN is invalid or no objects found under prefix.
          RuntimeError: If S3 calls fail with an unexpected error.
      """
-     res_creds = get_resource_creds(resource_url, api_key)
-     s3_arn = res_creds["s3_path"]
-     arn_prefix = "arn:aws:s3:::"
-     if not s3_arn.startswith(arn_prefix):
-         raise ValueError(f"Invalid S3 ARN: {s3_arn}")
+     res_credentials = get_resource_credentials(resource_url, api_key)
 
-     s3_path = s3_arn[len(arn_prefix) :]
-     bucket_name, *key_parts = s3_path.split("/")
-     key = "/".join(key_parts)
+     bucket_name = res_credentials.bucket_name()
+     key = res_credentials.object_key()
 
      output_path = Path(destination)
      output_path.mkdir(parents=True, exist_ok=True)
      s3_client = boto3.client(
          "s3",
-         aws_access_key_id=res_creds["access_key"],
-         aws_secret_access_key=res_creds["secret_key"],
-         aws_session_token=res_creds["session_token"],
+         aws_access_key_id=res_credentials.access_key,
+         aws_secret_access_key=res_credentials.secret_key,
+         aws_session_token=res_credentials.session_token,
      )
      downloaded_files = []
      try: