hafnia 0.4.3__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,29 +1,74 @@
-import os
-import shutil
-import subprocess
-import sys
-import tempfile
-import uuid
-from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import rich
 from rich import print as rprint
 
 from hafnia import http, utils
-from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED
-from hafnia.dataset.dataset_recipe.dataset_recipe import (
-    DatasetRecipe,
-    get_dataset_path_from_recipe,
-)
-from hafnia.dataset.hafnia_dataset import HafniaDataset
-from hafnia.http import fetch
-from hafnia.log import sys_logger, user_logger
+from hafnia.http import fetch, post
+from hafnia.log import user_logger
 from hafnia.platform.download import get_resource_credentials
-from hafnia.utils import progress_bar, timed
+from hafnia.platform.s5cmd_utils import ResourceCredentials
+from hafnia.utils import timed
 from hafnia_cli.config import Config
 
 
+@timed("Fetching dataset by name.")
+def get_dataset_by_name(dataset_name: str, cfg: Optional[Config] = None) -> Optional[Dict[str, Any]]:
+    """Get dataset details by name from the Hafnia platform."""
+    cfg = cfg or Config()
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    header = {"Authorization": cfg.api_key}
+    full_url = f"{endpoint_dataset}?name__iexact={dataset_name}"
+    datasets: List[Dict[str, Any]] = http.fetch(full_url, headers=header)  # type: ignore[assignment]
+    if len(datasets) == 0:
+        return None
+
+    if len(datasets) > 1:
+        raise ValueError(f"Multiple datasets found with the name '{dataset_name}'.")
+
+    return datasets[0]
+
+
+@timed("Fetching dataset by ID.")
+def get_dataset_by_id(dataset_id: str, cfg: Optional[Config] = None) -> Optional[Dict[str, Any]]:
+    """Get dataset details by ID from the Hafnia platform."""
+    cfg = cfg or Config()
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    header = {"Authorization": cfg.api_key}
+    full_url = f"{endpoint_dataset}/{dataset_id}"
+    dataset: Dict[str, Any] = http.fetch(full_url, headers=header)  # type: ignore[assignment]
+    if not dataset:
+        return None
+
+    return dataset
+
+
+def get_or_create_dataset(dataset_name: str = "", cfg: Optional[Config] = None) -> Dict[str, Any]:
+    """Create a new dataset on the Hafnia platform."""
+    cfg = cfg or Config()
+    dataset = get_dataset_by_name(dataset_name, cfg)
+    if dataset is not None:
+        user_logger.info(f"Dataset '{dataset_name}' already exists on the Hafnia platform.")
+        return dataset
+
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    header = {"Authorization": cfg.api_key}
+    dataset_title = dataset_name.replace("-", " ").title()  # convert dataset-name to title "Dataset Name"
+    payload = {
+        "title": dataset_title,
+        "name": dataset_name,
+        "overview": "No description provided.",
+    }
+
+    dataset = http.post(endpoint_dataset, headers=header, data=payload)  # type: ignore[assignment]
+
+    # TODO: Handle issue when dataset creation fails because name is taken by another user from a different organization
+    if not dataset:
+        raise ValueError("Failed to create dataset on the Hafnia platform. ")
+
+    return dataset
+
+
 @timed("Fetching dataset list.")
 def get_datasets(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
     """List available datasets on the Hafnia platform."""
@@ -50,161 +95,107 @@ def get_dataset_id(dataset_name: str, endpoint: str, api_key: str) -> str:
         raise ValueError("Dataset information is missing or invalid") from e
 
 
-def download_or_get_dataset_path(
-    dataset_name: str,
-    cfg: Optional[Config] = None,
-    path_datasets_folder: Optional[str] = None,
-    force_redownload: bool = False,
-    download_files: bool = True,
-) -> Path:
-    """Download or get the path of the dataset."""
-    recipe_explicit = DatasetRecipe.from_implicit_form(dataset_name)
-    path_dataset = get_dataset_path_from_recipe(recipe_explicit, path_datasets=path_datasets_folder)
-
-    is_dataset_valid = HafniaDataset.check_dataset_path(path_dataset, raise_error=False)
-    if is_dataset_valid and not force_redownload:
-        user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
-        return path_dataset
-
+@timed("Get upload access credentials")
+def get_upload_credentials(dataset_name: str, cfg: Optional[Config] = None) -> Optional[ResourceCredentials]:
+    """Get dataset details by name from the Hafnia platform."""
     cfg = cfg or Config()
-    api_key = cfg.api_key
+    dataset_response = get_dataset_by_name(dataset_name=dataset_name, cfg=cfg)
+    if dataset_response is None:
+        return None
+
+    return get_upload_credentials_by_id(dataset_response["id"], cfg=cfg)
 
-    shutil.rmtree(path_dataset, ignore_errors=True)
+
+@timed("Get upload access credentials by ID")
+def get_upload_credentials_by_id(dataset_id: str, cfg: Optional[Config] = None) -> Optional[ResourceCredentials]:
+    """Get dataset details by ID from the Hafnia platform."""
+    cfg = cfg or Config()
 
     endpoint_dataset = cfg.get_platform_endpoint("datasets")
-    dataset_id = get_dataset_id(dataset_name=dataset_name, endpoint=endpoint_dataset, api_key=api_key)
-    if dataset_id is None:
-        sys_logger.error(f"Dataset '{dataset_name}' not found on the Hafnia platform.")
+    header = {"Authorization": cfg.api_key}
+    full_url = f"{endpoint_dataset}/{dataset_id}/temporary-credentials-upload"
+    credentials_response: Dict = http.fetch(full_url, headers=header)  # type: ignore[assignment]
 
+    return ResourceCredentials.fix_naming(credentials_response)
+
+
+@timed("Get read access credentials by ID")
+def get_read_credentials_by_id(dataset_id: str, cfg: Optional[Config] = None) -> Optional[ResourceCredentials]:
+    """Get dataset read access credentials by ID from the Hafnia platform."""
+    cfg = cfg or Config()
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
     if utils.is_hafnia_cloud_job():
         credentials_endpoint_suffix = "temporary-credentials-hidden"  # Access to hidden datasets
     else:
         credentials_endpoint_suffix = "temporary-credentials"  # Access to sample dataset
     access_dataset_endpoint = f"{endpoint_dataset}/{dataset_id}/{credentials_endpoint_suffix}"
+    resource_credentials = get_resource_credentials(access_dataset_endpoint, cfg.api_key)
+    return resource_credentials
+
+
+@timed("Get read access credentials by name")
+def get_read_credentials_by_name(dataset_name: str, cfg: Optional[Config] = None) -> Optional[ResourceCredentials]:
+    """Get dataset read access credentials by name from the Hafnia platform."""
+    cfg = cfg or Config()
+    dataset_response = get_dataset_by_name(dataset_name=dataset_name, cfg=cfg)
+    if dataset_response is None:
+        return None
+
+    return get_read_credentials_by_id(dataset_response["id"], cfg=cfg)
+
+
+@timed("Delete dataset by id")
+def delete_dataset_by_id(dataset_id: str, cfg: Optional[Config] = None) -> Dict:
+    cfg = cfg or Config()
+    endpoint_dataset = cfg.get_platform_endpoint("datasets")
+    header = {"Authorization": cfg.api_key}
+    full_url = f"{endpoint_dataset}/{dataset_id}"
+    return http.delete(full_url, headers=header)  # type: ignore
 
-    download_dataset_from_access_endpoint(
-        endpoint=access_dataset_endpoint,
-        api_key=api_key,
-        path_dataset=path_dataset,
-        download_files=download_files,
-    )
-    return path_dataset
 
+@timed("Delete dataset by name")
+def delete_dataset_by_name(dataset_name: str, cfg: Optional[Config] = None) -> Dict:
+    cfg = cfg or Config()
+    dataset_response = get_dataset_by_name(dataset_name=dataset_name, cfg=cfg)
+    if dataset_response is None:
+        raise ValueError(f"Dataset '{dataset_name}' not found on the Hafnia platform.")
+
+    dataset_id = dataset_response["id"]  # type: ignore[union-attr]
+    response = delete_dataset_by_id(dataset_id=dataset_id, cfg=cfg)
+    user_logger.info(f"Dataset '{dataset_name}' has been deleted from the Hafnia platform.")
+    return response
 
-def download_dataset_from_access_endpoint(
-    endpoint: str,
-    api_key: str,
-    path_dataset: Path,
-    download_files: bool = True,
+
+def delete_dataset_completely_by_name(
+    dataset_name: str,
+    interactive: bool = True,
+    cfg: Optional[Config] = None,
 ) -> None:
-    resource_credentials = get_resource_credentials(endpoint, api_key)
+    from hafnia.dataset.operations.dataset_s3_storage import delete_hafnia_dataset_files_on_platform
 
-    local_dataset_paths = [(path_dataset / filename).as_posix() for filename in DATASET_FILENAMES_REQUIRED]
-    s3_uri = resource_credentials.s3_uri()
-    s3_dataset_files = [f"{s3_uri}/{filename}" for filename in DATASET_FILENAMES_REQUIRED]
+    cfg = cfg or Config()
 
-    envs = resource_credentials.aws_credentials()
-    try:
-        fast_copy_files_s3(
-            src_paths=s3_dataset_files,
-            dst_paths=local_dataset_paths,
-            append_envs=envs,
-            description="Downloading annotations",
-        )
-    except ValueError as e:
-        user_logger.error(f"Failed to download annotations: {e}")
+    is_deleted = delete_hafnia_dataset_files_on_platform(
+        dataset_name=dataset_name,
+        interactive=interactive,
+        cfg=cfg,
+    )
+    if not is_deleted:
         return
+    delete_dataset_by_name(dataset_name, cfg=cfg)
 
-    if not download_files:
-        return
-    dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
-    try:
-        dataset = dataset.download_files_aws(path_dataset, aws_credentials=resource_credentials, force_redownload=True)
-    except ValueError as e:
-        user_logger.error(f"Failed to download images: {e}")
-        return
-    dataset.write_annotations(path_folder=path_dataset)  # Overwrite annotations as files have been re-downloaded
-
-
-def fast_copy_files_s3(
-    src_paths: List[str],
-    dst_paths: List[str],
-    append_envs: Optional[Dict[str, str]] = None,
-    description: str = "Copying files",
-) -> List[str]:
-    if len(src_paths) != len(dst_paths):
-        raise ValueError("Source and destination paths must have the same length.")
-    cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
-    lines = execute_s5cmd_commands(cmds, append_envs=append_envs, description=description)
-    return lines
-
-
-def find_s5cmd() -> Optional[str]:
-    """Locate the s5cmd executable across different installation methods.
-
-    Searches for s5cmd in:
-    1. System PATH (via shutil.which)
-    2. Python bin directory (Unix-like systems)
-    3. Python executable directory (direct installs)
-
-    Returns:
-        str: Absolute path to s5cmd executable if found, None otherwise.
-    """
-    result = shutil.which("s5cmd")
-    if result:
-        return result
-    python_dir = Path(sys.executable).parent
-    locations = (python_dir / "Scripts" / "s5cmd.exe", python_dir / "bin" / "s5cmd", python_dir / "s5cmd")
-    for loc in locations:
-        if loc.exists():
-            return str(loc)
-    return None
-
-
-def execute_s5cmd_commands(
-    commands: List[str],
-    append_envs: Optional[Dict[str, str]] = None,
-    description: str = "Executing s5cmd commands",
-) -> List[str]:
-    append_envs = append_envs or {}
-    # In Windows default "Temp" directory can not be deleted that is why we need to create a
-    # temporary directory.
-    with tempfile.TemporaryDirectory() as temp_dir:
-        tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
-        tmp_file_path.write_text("\n".join(commands))
-
-        s5cmd_bin = find_s5cmd()
-        if s5cmd_bin is None:
-            raise ValueError("Can not find s5cmd executable.")
-        run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
-        sys_logger.debug(run_cmds)
-        envs = os.environ.copy()
-        envs.update(append_envs)
-
-        process = subprocess.Popen(
-            run_cmds,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            universal_newlines=True,
-            env=envs,
-        )
-
-        error_lines = []
-        lines = []
-        for line in progress_bar(process.stdout, total=len(commands), description=description):  # type: ignore[arg-type]
-            if "ERROR" in line or "error" in line:
-                error_lines.append(line.strip())
-            lines.append(line.strip())
-
-        if len(error_lines) > 0:
-            show_n_lines = min(5, len(error_lines))
-            str_error_lines = "\n".join(error_lines[:show_n_lines])
-            user_logger.error(
-                f"Detected {len(error_lines)} errors occurred while executing a total of {len(commands)} "
-                f" commands with s5cmd. The first {show_n_lines} is printed below:\n{str_error_lines}"
-            )
-            raise RuntimeError("Errors occurred during s5cmd execution.")
-        return lines
+
+@timed("Import dataset details to platform")
+def upload_dataset_details(cfg: Config, data: dict, dataset_name: str) -> dict:
+    dataset_endpoint = cfg.get_platform_endpoint("datasets")
+    dataset_id = get_dataset_id(dataset_name, dataset_endpoint, cfg.api_key)
+
+    import_endpoint = f"{dataset_endpoint}/{dataset_id}/import"
+    headers = {"Authorization": cfg.api_key}
+
+    user_logger.info("Exporting dataset details to platform. This may take up to 30 seconds...")
+    response = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
+    return response  # type: ignore[return-value]
 
 
 TABLE_FIELDS = {
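Note: this hunk replaces the local download helpers with a set of platform CRUD helpers (`get_or_create_dataset`, `get_upload_credentials`, `get_read_credentials_by_name`, `delete_dataset_completely_by_name`, `upload_dataset_details`). A minimal usage sketch based only on the signatures above; the module path `hafnia.platform.datasets` is an assumption (the file name is not shown in this diff), and `Config()` is assumed to resolve a stored Hafnia API key:

```python
# Sketch only: the module path is assumed, not confirmed by this diff.
from hafnia_cli.config import Config
from hafnia.platform.datasets import (  # assumed location of the functions above
    get_or_create_dataset,
    get_read_credentials_by_name,
    delete_dataset_completely_by_name,
)

cfg = Config()  # assumed to pick up the stored Hafnia API key
dataset = get_or_create_dataset("my-dataset", cfg=cfg)        # dict describing the dataset record
creds = get_read_credentials_by_name("my-dataset", cfg=cfg)   # ResourceCredentials or None
if creds is not None:
    print(creds.s3_uri())  # e.g. "s3://<bucket>/<prefix>"

# Removes the dataset files on S3 first, then the platform record (asks for confirmation by default).
delete_dataset_completely_by_name("my-dataset", interactive=True, cfg=cfg)
```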
@@ -5,9 +5,9 @@ import boto3
 from botocore.exceptions import ClientError
 from rich.progress import Progress
 
-from hafnia.dataset.dataset_names import ResourceCredentials
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
+from hafnia.platform.s5cmd_utils import ResourceCredentials
 
 
 def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
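The only change in this hunk is where `ResourceCredentials` is imported from: `hafnia.platform.s5cmd_utils` instead of `hafnia.dataset.dataset_names`. Downstream code that imported it from the old location would presumably need the same one-line update; a sketch, not part of the released diff:

```python
# 0.4.3 (old import location):
# from hafnia.dataset.dataset_names import ResourceCredentials

# 0.5.1 (new import location):
from hafnia.platform.s5cmd_utils import ResourceCredentials
```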
@@ -0,0 +1,266 @@
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import uuid
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import boto3
+from botocore.exceptions import UnauthorizedSSOTokenError
+from pydantic import BaseModel, field_validator
+
+from hafnia.log import sys_logger, user_logger
+from hafnia.utils import progress_bar
+
+
+def find_s5cmd() -> Optional[str]:
+    """Locate the s5cmd executable across different installation methods.
+
+    Searches for s5cmd in:
+    1. System PATH (via shutil.which)
+    2. Python bin directory (Unix-like systems)
+    3. Python executable directory (direct installs)
+
+    Returns:
+        str: Absolute path to s5cmd executable if found, None otherwise.
+    """
+    result = shutil.which("s5cmd")
+    if result:
+        return result
+    python_dir = Path(sys.executable).parent
+    locations = (
+        python_dir / "Scripts" / "s5cmd.exe",
+        python_dir / "bin" / "s5cmd",
+        python_dir / "s5cmd",
+    )
+    for loc in locations:
+        if loc.exists():
+            return str(loc)
+    return None
+
+
+def execute_command(args: List[str], append_envs: Optional[Dict[str, str]] = None) -> subprocess.CompletedProcess:
+    s5cmd_bin = find_s5cmd()
+    cmds = [s5cmd_bin] + args
+    envs = os.environ.copy()
+    if append_envs:
+        envs.update(append_envs)
+
+    result = subprocess.run(
+        cmds,  # type: ignore[arg-type]
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        universal_newlines=True,
+        env=envs,
+    )
+    return result
+
+
+def execute_commands(
+    commands: List[str],
+    append_envs: Optional[Dict[str, str]] = None,
+    description: str = "Executing s5cmd commands",
+) -> List[str]:
+    append_envs = append_envs or {}
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        tmp_file_path = Path(temp_dir, f"{uuid.uuid4().hex}.txt")
+        tmp_file_path.write_text("\n".join(commands))
+
+        s5cmd_bin = find_s5cmd()
+        if s5cmd_bin is None:
+            raise ValueError("Can not find s5cmd executable.")
+        run_cmds = [s5cmd_bin, "run", str(tmp_file_path)]
+        sys_logger.debug(run_cmds)
+        envs = os.environ.copy()
+        envs.update(append_envs)
+
+        process = subprocess.Popen(
+            run_cmds,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            universal_newlines=True,
+            env=envs,
+        )
+
+        error_lines = []
+        lines = []
+        for line in progress_bar(process.stdout, total=len(commands), description=description):  # type: ignore[arg-type]
+            if "ERROR" in line or "error" in line:
+                error_lines.append(line.strip())
+            lines.append(line.strip())
+
+        if len(error_lines) > 0:
+            show_n_lines = min(5, len(error_lines))
+            str_error_lines = "\n".join(error_lines[:show_n_lines])
+            user_logger.error(
+                f"Detected {len(error_lines)} errors occurred while executing a total of {len(commands)} "
+                f" commands with s5cmd. The first {show_n_lines} is printed below:\n{str_error_lines}"
+            )
+            raise RuntimeError("Errors occurred during s5cmd execution.")
+        return lines
+
+
+def delete_bucket_content(
+    bucket_prefix: str,
+    remove_bucket: bool = True,
+    append_envs: Optional[Dict[str, str]] = None,
+) -> None:
+    # Remove all files in the bucket
+    returns = execute_command(["rm", f"{bucket_prefix}/*"], append_envs=append_envs)
+
+    if returns.returncode != 0:
+        bucket_content_is_already_deleted = "no object found" in returns.stderr.strip()
+        bucket_is_already_deleted = "NoSuchBucket" in returns.stderr.strip()
+        if bucket_content_is_already_deleted:
+            user_logger.info(f"No action was taken. S3 bucket '{bucket_prefix}' is already empty.")
+        elif bucket_is_already_deleted:
+            user_logger.info(f"No action was taken. S3 bucket '{bucket_prefix}' does not exist.")
+            return
+        else:
+            user_logger.error("Error during s5cmd rm command:")
+            user_logger.error(returns.stdout)
+            user_logger.error(returns.stderr)
+            raise RuntimeError(f"Failed to delete all files in S3 bucket '{bucket_prefix}'.")
+
+    if remove_bucket:
+        # Remove the bucket itself
+        returns = execute_command(["rb", bucket_prefix], append_envs=append_envs)
+        if returns.returncode != 0:
+            user_logger.error("Error during s5cmd rb command:")
+            user_logger.error(returns.stdout)
+            user_logger.error(returns.stderr)
+            raise RuntimeError(f"Failed to delete S3 bucket '{bucket_prefix}'.")
+        user_logger.info(f"S3 bucket '{bucket_prefix}' has been deleted.")
+
+
+def list_bucket(bucket_prefix: str, append_envs: Optional[Dict[str, str]] = None) -> List[str]:
+    output = execute_command(["ls", f"{bucket_prefix}/*"], append_envs=append_envs)
+    has_missing_folder = "no object found" in output.stderr.strip()
+    if output.returncode != 0 and not has_missing_folder:
+        user_logger.error("Error during s5cmd ls command:")
+        user_logger.error(output.stderr)
+        raise RuntimeError(f"Failed to list dataset in S3 bucket '{bucket_prefix}'.")
+
+    files_in_s3 = [f"{bucket_prefix}/{line.split(' ')[-1]}" for line in output.stdout.splitlines()]
+    return files_in_s3
+
+
+def fast_copy_files(
+    src_paths: List[str],
+    dst_paths: List[str],
+    append_envs: Optional[Dict[str, str]] = None,
+    description: str = "Copying files",
+) -> List[str]:
+    if len(src_paths) != len(dst_paths):
+        raise ValueError("Source and destination paths must have the same length.")
+    cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
+    lines = execute_commands(cmds, append_envs=append_envs, description=description)
+    return lines
+
+
+ARN_PREFIX = "arn:aws:s3:::"
+
+
+class AwsCredentials(BaseModel):
+    access_key: str
+    secret_key: str
+    session_token: str
+    region: Optional[str]
+
+    def aws_credentials(self) -> Dict[str, str]:
+        """
+        Returns the AWS credentials as a dictionary.
+        """
+        environment_vars = {
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_SESSION_TOKEN": self.session_token,
+        }
+        if self.region:
+            environment_vars["AWS_REGION"] = self.region
+
+        return environment_vars
+
+    @staticmethod
+    def from_session(session: boto3.Session) -> "AwsCredentials":
+        """
+        Creates AwsCredentials from a Boto3 session.
+        """
+        try:
+            frozen_credentials = session.get_credentials().get_frozen_credentials()
+        except UnauthorizedSSOTokenError as e:
+            raise RuntimeError(
+                f"Failed to get AWS credentials from the session for profile '{session.profile_name}'.\n"
+                f"Ensure the profile exists in your AWS config in '~/.aws/config' and that you are logged in via AWS SSO.\n"
+                f"\tUse 'aws sso login --profile {session.profile_name}' to log in."
+            ) from e
+        return AwsCredentials(
+            access_key=frozen_credentials.access_key,
+            secret_key=frozen_credentials.secret_key,
+            session_token=frozen_credentials.token,
+            region=session.region_name,
+        )
+
+    def to_resource_credentials(self, bucket_name: str) -> "ResourceCredentials":
+        """
+        Converts AwsCredentials to ResourceCredentials by adding the S3 ARN.
+        """
+        payload = self.model_dump()
+        payload["s3_arn"] = f"{ARN_PREFIX}{bucket_name}"
+        return ResourceCredentials(**payload)
+
+
+class ResourceCredentials(AwsCredentials):
+    s3_arn: str
+
+    @staticmethod
+    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+        """
+        The endpoint returns a payload with a key called 's3_path', but it
+        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+        """
+        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+            payload["s3_arn"] = payload.pop("s3_path")
+
+        if "region" not in payload:
+            payload["region"] = "eu-west-1"
+        return ResourceCredentials(**payload)
+
+    @field_validator("s3_arn")
+    @classmethod
+    def validate_s3_arn(cls, value: str) -> str:
+        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+        if not value.startswith("arn:aws:s3:::"):
+            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+        return value
+
+    def s3_path(self) -> str:
+        """
+        Extracts the S3 path from the ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+        """
+        return self.s3_arn[len(ARN_PREFIX) :]
+
+    def s3_uri(self) -> str:
+        """
+        Converts the S3 ARN to a URI format.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+        """
+        return f"s3://{self.s3_path()}"
+
+    def bucket_name(self) -> str:
+        """
+        Extracts the bucket name from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+        """
+        return self.s3_path().split("/")[0]
+
+    def object_key(self) -> str:
+        """
+        Extracts the object key from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+        """
+        return "/".join(self.s3_path().split("/")[1:])
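This new module bundles the s5cmd wrappers with pydantic models for temporary AWS credentials. A short sketch of how `ResourceCredentials.fix_naming` and the ARN helpers behave, derived from the code above; the payload values and the object name `annotations.jsonl` are made up for illustration:

```python
from hafnia.platform.s5cmd_utils import ResourceCredentials, fast_copy_files

# Shaped like the platform's credentials response; all values are placeholders.
payload = {
    "access_key": "AKIA...",
    "secret_key": "...",
    "session_token": "...",
    "s3_path": "arn:aws:s3:::my-bucket/datasets/my-dataset",
}
creds = ResourceCredentials.fix_naming(payload)  # renames s3_path -> s3_arn, defaults region to eu-west-1

print(creds.s3_uri())       # s3://my-bucket/datasets/my-dataset
print(creds.bucket_name())  # my-bucket
print(creds.object_key())   # datasets/my-dataset

# Copy a single object with s5cmd, passing the credentials as environment variables.
fast_copy_files(
    src_paths=[f"{creds.s3_uri()}/annotations.jsonl"],  # hypothetical object name
    dst_paths=["./annotations.jsonl"],
    append_envs=creds.aws_credentials(),
)
```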
hafnia/utils.py CHANGED
@@ -65,6 +65,10 @@ def timed(label: str):
     return decorator
 
 
+def get_path_dataset_gallery_images(dataset_name: str) -> Path:
+    return PATH_DATASETS / dataset_name / "gallery_images"
+
+
 def get_path_hafnia_cache() -> Path:
     return Path.home() / "hafnia"
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hafnia
-Version: 0.4.3
+Version: 0.5.1
 Summary: Python SDK for communication with Hafnia platform.
 Author-email: Milestone Systems <hafniaplatform@milestone.dk>
 License-File: LICENSE
@@ -10,7 +10,7 @@ Requires-Dist: click>=8.1.8
 Requires-Dist: emoji>=2.14.1
 Requires-Dist: flatten-dict>=0.4.2
 Requires-Dist: keyring>=25.6.0
-Requires-Dist: mcp==1.16.0
+Requires-Dist: mcp>=1.23.0
 Requires-Dist: mlflow>=3.4.0
 Requires-Dist: more-itertools>=10.7.0
 Requires-Dist: opencv-python-headless>=4.11.0.86
@@ -209,7 +209,7 @@ DatasetInfo(
 ```
 
 You can iterate and access samples in the dataset using the `HafniaDataset` object.
-Each sample contain image and annotations information.
+Each sample contain image and annotations information.
 
 ```python
 from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample