hafnia 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/dataset/dataset_helpers.py +59 -1
- hafnia/dataset/dataset_names.py +1 -108
- hafnia/dataset/dataset_recipe/dataset_recipe.py +48 -4
- hafnia/dataset/format_conversions/torchvision_datasets.py +2 -2
- hafnia/dataset/hafnia_dataset.py +163 -69
- hafnia/dataset/hafnia_dataset_types.py +142 -18
- hafnia/dataset/operations/dataset_s3_storage.py +7 -2
- hafnia/dataset/operations/table_transformations.py +0 -18
- hafnia/experiment/command_builder.py +686 -0
- hafnia/platform/datasets.py +32 -132
- hafnia/platform/download.py +1 -1
- hafnia/platform/s5cmd_utils.py +122 -3
- {hafnia-0.5.0.dist-info → hafnia-0.5.2.dist-info}/METADATA +3 -2
- {hafnia-0.5.0.dist-info → hafnia-0.5.2.dist-info}/RECORD +19 -18
- hafnia_cli/dataset_cmds.py +19 -13
- hafnia_cli/runc_cmds.py +7 -2
- {hafnia-0.5.0.dist-info → hafnia-0.5.2.dist-info}/WHEEL +0 -0
- {hafnia-0.5.0.dist-info → hafnia-0.5.2.dist-info}/entry_points.txt +0 -0
- {hafnia-0.5.0.dist-info → hafnia-0.5.2.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/dataset_helpers.py
CHANGED

@@ -3,12 +3,70 @@ import math
 import random
 import shutil
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import xxhash
+from packaging.version import InvalidVersion, Version
 from PIL import Image
 
+from hafnia.log import user_logger
+
+
+def is_valid_version_string(version: Optional[str], allow_none: bool = False, allow_latest: bool = False) -> bool:
+    if allow_none and version is None:
+        return True
+    if allow_latest and version == "latest":
+        return True
+    return version_from_string(version, raise_error=False) is not None
+
+
+def version_from_string(version: Optional[str], raise_error: bool = True) -> Optional[Version]:
+    if version is None:
+        if raise_error:
+            raise ValueError("Version is 'None'. A valid version string is required e.g '1.0.0'")
+        return None
+
+    try:
+        version_casted = Version(version)
+    except (InvalidVersion, TypeError) as e:
+        if raise_error:
+            raise ValueError(f"Invalid version string/type: {version}") from e
+        return None
+
+    # Check if version is semantic versioning (MAJOR.MINOR.PATCH)
+    if len(version_casted.release) < 3:
+        if raise_error:
+            raise ValueError(f"Version string '{version}' is not semantic versioning (MAJOR.MINOR.PATCH)")
+        return None
+    return version_casted
+
+
+def dataset_name_and_version_from_string(
+    string: str,
+    resolve_missing_version: bool = True,
+) -> Tuple[str, Optional[str]]:
+    if not isinstance(string, str):
+        raise TypeError(f"'{type(string)}' for '{string}' is an unsupported type. Expected 'str' e.g 'mnist:1.0.0'")
+
+    parts = string.split(":")
+    if len(parts) == 1:
+        dataset_name = parts[0]
+        if resolve_missing_version:
+            version = "latest"  # Default to 'latest' if version is missing. This will be resolved to a specific version later.
+            user_logger.info(f"Version is missing in dataset name: {string}. Defaulting to version='latest'.")
+        else:
+            raise ValueError(f"Version is missing in dataset name: {string}. Use 'name:version'")
+    elif len(parts) == 2:
+        dataset_name, version = parts
+    else:
+        raise ValueError(f"Invalid dataset name format: {string}. Use 'name' or 'name:version' ")
+
+    if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+        raise ValueError(f"Invalid version string: {version}. Use semantic versioning e.g. '1.0.0' or 'latest'")
+
+    return dataset_name, version
+
 
 def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
     samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
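The helpers above define how dataset versions are parsed throughout this release: only full MAJOR.MINOR.PATCH strings are accepted, with 'latest' and None allowed where explicitly requested. A minimal usage sketch (the dataset name 'mnist' is hypothetical, used only for illustration):

    from hafnia.dataset.dataset_helpers import (
        dataset_name_and_version_from_string,
        is_valid_version_string,
    )

    # Split a combined "name:version" string; with resolve_missing_version=True a bare
    # name falls back to version="latest" (and an info message is logged).
    name, version = dataset_name_and_version_from_string("mnist:1.0.0")  # ("mnist", "1.0.0")
    name, version = dataset_name_and_version_from_string("mnist")        # ("mnist", "latest")

    # Only full semantic versions count as valid; 'latest'/None must be opted into.
    is_valid_version_string("1.0.0")                      # True
    is_valid_version_string("1.0")                        # False (missing PATCH)
    is_valid_version_string("latest", allow_latest=True)  # True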

hafnia/dataset/dataset_names.py
CHANGED

@@ -1,9 +1,5 @@
 from enum import Enum
-from typing import
-
-import boto3
-from botocore.exceptions import UnauthorizedSSOTokenError
-from pydantic import BaseModel, field_validator
+from typing import List
 
 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"

@@ -22,7 +18,6 @@ class DeploymentStage(Enum):
     PRODUCTION = "production"
 
 
-ARN_PREFIX = "arn:aws:s3:::"
 TAG_IS_SAMPLE = "sample"
 
 OPS_REMOVE_CLASS = "__REMOVE__"

@@ -126,105 +121,3 @@ class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
-
-
-class AwsCredentials(BaseModel):
-    access_key: str
-    secret_key: str
-    session_token: str
-    region: Optional[str]
-
-    def aws_credentials(self) -> Dict[str, str]:
-        """
-        Returns the AWS credentials as a dictionary.
-        """
-        environment_vars = {
-            "AWS_ACCESS_KEY_ID": self.access_key,
-            "AWS_SECRET_ACCESS_KEY": self.secret_key,
-            "AWS_SESSION_TOKEN": self.session_token,
-        }
-        if self.region:
-            environment_vars["AWS_REGION"] = self.region
-
-        return environment_vars
-
-    @staticmethod
-    def from_session(session: boto3.Session) -> "AwsCredentials":
-        """
-        Creates AwsCredentials from a Boto3 session.
-        """
-        try:
-            frozen_credentials = session.get_credentials().get_frozen_credentials()
-        except UnauthorizedSSOTokenError as e:
-            raise RuntimeError(
-                f"Failed to get AWS credentials from the session for profile '{session.profile_name}'.\n"
-                f"Ensure the profile exists in your AWS config in '~/.aws/config' and that you are logged in via AWS SSO.\n"
-                f"\tUse 'aws sso login --profile {session.profile_name}' to log in."
-            ) from e
-        return AwsCredentials(
-            access_key=frozen_credentials.access_key,
-            secret_key=frozen_credentials.secret_key,
-            session_token=frozen_credentials.token,
-            region=session.region_name,
-        )
-
-    def to_resource_credentials(self, bucket_name: str) -> "ResourceCredentials":
-        """
-        Converts AwsCredentials to ResourceCredentials by adding the S3 ARN.
-        """
-        payload = self.model_dump()
-        payload["s3_arn"] = f"{ARN_PREFIX}{bucket_name}"
-        return ResourceCredentials(**payload)
-
-
-class ResourceCredentials(AwsCredentials):
-    s3_arn: str
-
-    @staticmethod
-    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
-        """
-        The endpoint returns a payload with a key called 's3_path', but it
-        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
-        """
-        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
-            payload["s3_arn"] = payload.pop("s3_path")
-
-        if "region" not in payload:
-            payload["region"] = "eu-west-1"
-        return ResourceCredentials(**payload)
-
-    @field_validator("s3_arn")
-    @classmethod
-    def validate_s3_arn(cls, value: str) -> str:
-        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
-        if not value.startswith("arn:aws:s3:::"):
-            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
-        return value
-
-    def s3_path(self) -> str:
-        """
-        Extracts the S3 path from the ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
-        """
-        return self.s3_arn[len(ARN_PREFIX) :]
-
-    def s3_uri(self) -> str:
-        """
-        Converts the S3 ARN to a URI format.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
-        """
-        return f"s3://{self.s3_path()}"
-
-    def bucket_name(self) -> str:
-        """
-        Extracts the bucket name from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
-        """
-        return self.s3_path().split("/")[0]
-
-    def object_key(self) -> str:
-        """
-        Extracts the object key from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
-        """
-        return "/".join(self.s3_path().split("/")[1:])
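The credential classes removed here are not dropped from the package: the hafnia/dataset/hafnia_dataset.py hunks below import AwsCredentials and ResourceCredentials from hafnia.platform.s5cmd_utils instead (that module grows by +122 lines in this release). A sketch of the ARN helpers documented above, with hypothetical values:

    # Hypothetical credentials; only the s3_arn-derived helpers are exercised.
    creds = ResourceCredentials(
        access_key="AKIA-EXAMPLE",
        secret_key="example-secret",
        session_token="example-token",
        region="eu-west-1",
        s3_arn="arn:aws:s3:::my-bucket/my-prefix",
    )
    creds.s3_uri()           # "s3://my-bucket/my-prefix"
    creds.bucket_name()      # "my-bucket"
    creds.object_key()       # "my-prefix"
    creds.aws_credentials()  # env-style dict: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, AWS_REGION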

hafnia/dataset/dataset_recipe/dataset_recipe.py
CHANGED

@@ -11,14 +11,19 @@ from pydantic import (
 )
 
 from hafnia import utils
+from hafnia.dataset.dataset_helpers import dataset_name_and_version_from_string
 from hafnia.dataset.dataset_recipe import recipe_transforms
 from hafnia.dataset.dataset_recipe.recipe_types import (
     RecipeCreation,
     RecipeTransform,
     Serializable,
 )
-from hafnia.dataset.hafnia_dataset import
+from hafnia.dataset.hafnia_dataset import (
+    HafniaDataset,
+    available_dataset_versions_from_name,
+)
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
 
 
 class DatasetRecipe(Serializable):

@@ -41,8 +46,31 @@ class DatasetRecipe(Serializable):
 
     ### Creation Methods (using the 'from_X' )###
     @staticmethod
-    def from_name(
-
+    def from_name(
+        name: str,
+        version: Optional[str] = None,
+        force_redownload: bool = False,
+        download_files: bool = True,
+    ) -> DatasetRecipe:
+        if version == "latest":
+            user_logger.info(
+                f"The dataset '{name}' in a dataset recipe uses 'latest' as version. For dataset recipes the "
+                "version is pinned to a specific version. Consider specifying a specific version to ensure "
+                "reproducibility of your experiments. "
+            )
+            available_versions = available_dataset_versions_from_name(name)
+            version = str(max(available_versions))
+        if version is None:
+            available_versions = available_dataset_versions_from_name(name)
+            str_versions = ", ".join([str(v) for v in available_versions])
+            raise ValueError(
+                f"Version must be specified when creating a DatasetRecipe from name. "
+                f"Available versions are: {str_versions}"
+            )
+
+        creation = FromName(
+            name=name, version=version, force_redownload=force_redownload, download_files=download_files
+        )
         return DatasetRecipe(creation=creation)
 
     @staticmethod

@@ -125,6 +153,21 @@ class DatasetRecipe(Serializable):
         recipe_id = recipe["id"]
         return DatasetRecipe.from_recipe_id(recipe_id)
 
+    @staticmethod
+    def from_name_and_version_string(string: str, resolve_missing_version: bool = False) -> "DatasetRecipe":
+        """
+        Validates and converts a dataset name and version string (name:version) to a DatasetRecipe.from_name recipe.
+        If version is missing and 'resolve_missing_version' is True, it will default to 'latest'.
+        If resolve_missing_version is False, it will raise an error if version is missing.
+        """
+
+        dataset_name, version = dataset_name_and_version_from_string(
+            string=string,
+            resolve_missing_version=resolve_missing_version,
+        )
+
+        return DatasetRecipe.from_name(name=dataset_name, version=version)
+
     @staticmethod
     def from_implicit_form(recipe: Any) -> DatasetRecipe:
         """

@@ -180,7 +223,7 @@ class DatasetRecipe(Serializable):
             return recipe
 
         if isinstance(recipe, str):  # str-type is convert to DatasetFromName
-            return DatasetRecipe.
+            return DatasetRecipe.from_name_and_version_string(string=recipe, resolve_missing_version=True)
 
         if isinstance(recipe, Path):  # Path-type is convert to DatasetFromPath
             return DatasetRecipe.from_path(path_folder=recipe)

@@ -409,6 +452,7 @@ class FromPath(RecipeCreation):
 
 class FromName(RecipeCreation):
     name: str
+    version: Optional[str] = None
     force_redownload: bool = False
     download_files: bool = True
 
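In short, DatasetRecipe.from_name now pins a concrete dataset version into the recipe (resolving and logging 'latest', raising when no version can be determined), and the new from_name_and_version_string handles the 'name:version' shorthand used by the implicit string form. A usage sketch ('mnist' is a hypothetical dataset name; the import path follows the file path shown above):

    from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

    # Explicit pinning keeps the recipe reproducible.
    recipe = DatasetRecipe.from_name(name="mnist", version="1.0.0")

    # 'latest' is resolved against the platform, logged, and then pinned; version=None raises
    # with the list of available versions.
    recipe = DatasetRecipe.from_name(name="mnist", version="latest")

    # Combined string form, as used by from_implicit_form for plain strings.
    recipe = DatasetRecipe.from_name_and_version_string("mnist:1.0.0")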

hafnia/dataset/format_conversions/torchvision_datasets.py
CHANGED

@@ -78,7 +78,7 @@ def caltech_101_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @article{FeiFei2004LearningGV,
         title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian

@@ -108,7 +108,7 @@ def caltech_256_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @misc{griffin_2023_5sv1j-ytw97,
         author = {Griffin, Gregory and

hafnia/dataset/hafnia_dataset.py
CHANGED

@@ -10,14 +10,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import polars as pl
 from packaging.version import Version
 
+from hafnia import utils
 from hafnia.dataset import dataset_helpers
+from hafnia.dataset.dataset_helpers import is_valid_version_string, version_from_string
 from hafnia.dataset.dataset_names import (
-    FILENAME_ANNOTATIONS_JSONL,
-    FILENAME_ANNOTATIONS_PARQUET,
-    FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
     TAG_IS_SAMPLE,
-    AwsCredentials,
     PrimitiveField,
     SampleField,
     SplitName,

@@ -28,7 +26,7 @@ from hafnia.dataset.format_conversions import (
     format_image_classification_folder,
     format_yolo,
 )
-from hafnia.dataset.hafnia_dataset_types import DatasetInfo, Sample
+from hafnia.dataset.hafnia_dataset_types import DatasetInfo, DatasetMetadataFilePaths, Sample
 from hafnia.dataset.operations import (
     dataset_stats,
     dataset_transformations,

@@ -36,6 +34,9 @@ from hafnia.dataset.operations import (
 )
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.log import user_logger
+from hafnia.platform import s5cmd_utils
+from hafnia.platform.datasets import get_read_credentials_by_name
+from hafnia.platform.s5cmd_utils import AwsCredentials, ResourceCredentials
 from hafnia.utils import progress_bar
 from hafnia_cli.config import Config
 

@@ -89,10 +90,11 @@ class HafniaDataset:
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
         path_folder = Path(path_folder)
-
+        metadata_file_paths = DatasetMetadataFilePaths.from_path(path_folder)
+        metadata_file_paths.exists(raise_error=True)
 
-        dataset_info = DatasetInfo.from_json_file(
-        samples =
+        dataset_info = DatasetInfo.from_json_file(Path(metadata_file_paths.dataset_info))
+        samples = metadata_file_paths.read_samples()
         samples, dataset_info = _dataset_corrections(samples, dataset_info)
 
         # Convert from relative paths to absolute paths

@@ -103,14 +105,24 @@ class HafniaDataset:
         return HafniaDataset(samples=samples, info=dataset_info)
 
     @staticmethod
-    def from_name(
+    def from_name(
+        name: str,
+        version: Optional[str] = None,
+        force_redownload: bool = False,
+        download_files: bool = True,
+    ) -> "HafniaDataset":
         """
         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
         """
-
-
+        if ":" in name:
+            name, version = dataset_helpers.dataset_name_and_version_from_string(name)
+            raise ValueError(
+                "The 'from_name' does not support the 'name:version' format. Please provide the version separately.\n"
+                f"E.g., HafniaDataset.from_name(name='{name}', version='{version}')"
+            )
         dataset_path = download_or_get_dataset_path(
             dataset_name=name,
+            version=version,
             force_redownload=force_redownload,
             download_files=download_files,
         )
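HafniaDataset.from_name therefore takes the version as a separate argument and rejects the combined 'name:version' form with a pointer to the correct call. A usage sketch ('mnist' is a hypothetical dataset name registered on the platform):

    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.from_name(name="mnist", version="1.0.0")

    # Raises ValueError and suggests passing the version separately:
    # HafniaDataset.from_name(name="mnist:1.0.0")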

@@ -523,30 +535,6 @@ class HafniaDataset:
         table = dataset.samples if isinstance(dataset, HafniaDataset) else dataset
         return table_transformations.has_primitive(table, PrimitiveType)
 
-    @staticmethod
-    def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
-        """
-        Checks if the dataset path exists and contains the required files.
-        Returns True if the dataset is valid, otherwise raises an error or returns False.
-        """
-        if not path_dataset.exists():
-            if raise_error:
-                raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
-            return False
-
-        required_files = [
-            FILENAME_DATASET_INFO,
-            FILENAME_ANNOTATIONS_JSONL,
-            FILENAME_ANNOTATIONS_PARQUET,
-        ]
-        for filename in required_files:
-            if not (path_dataset / filename).exists():
-                if raise_error:
-                    raise FileNotFoundError(f"Required file {filename} not found in {path_dataset}.")
-                return False
-
-        return True
-
     def copy(self) -> "HafniaDataset":
         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
 

@@ -584,16 +572,18 @@ class HafniaDataset:
         """
         Writes only the annotations files (JSONL and Parquet) to the specified folder.
         """
+
         user_logger.info(f"Writing dataset annotations to {path_folder}...")
-
-
-
-        dataset.info.write_json(
+        metadata_file_paths = DatasetMetadataFilePaths.from_path(path_folder)
+        path_dataset_info = Path(metadata_file_paths.dataset_info)
+        path_dataset_info.parent.mkdir(parents=True, exist_ok=True)
+        dataset.info.write_json(path_dataset_info)
 
         samples = dataset.samples
         if drop_null_cols:  # Drops all unused/Null columns
             samples = samples.drop(pl.selectors.by_dtype(pl.Null))
 
+        path_folder = path_folder.absolute()
         # Store only relative paths in the annotations files
         if SampleField.FILE_PATH in samples.columns:  # We drop column for remote datasets
             absolute_paths = samples[SampleField.FILE_PATH].to_list()

@@ -601,8 +591,11 @@ class HafniaDataset:
             samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))
         else:
             samples = samples.with_columns(pl.lit("").alias(SampleField.FILE_PATH))
-
-
+
+        if metadata_file_paths.annotations_jsonl:
+            samples.write_ndjson(Path(metadata_file_paths.annotations_jsonl))  # Json for readability
+        if metadata_file_paths.annotations_parquet:
+            samples.write_parquet(Path(metadata_file_paths.annotations_parquet))  # Parquet for speed
 
     def delete_on_platform(dataset: HafniaDataset, interactive: bool = True) -> None:
         """

@@ -707,6 +700,42 @@ class HafniaDataset:
         return True
 
 
+def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+    format_version_of_dataset = Version(dataset_info.format_version)
+
+    ## Backwards compatibility fixes for older dataset versions
+    if format_version_of_dataset < Version("0.2.0"):
+        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
+
+    if "file_name" in samples.columns:
+        samples = samples.rename({"file_name": SampleField.FILE_PATH})
+
+    if SampleField.SAMPLE_INDEX not in samples.columns:
+        samples = table_transformations.add_sample_index(samples)
+
+    # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+    if SampleField.TAGS not in samples.columns:
+        tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+    if SampleField.STORAGE_FORMAT not in samples.columns:
+        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
+
+    if format_version_of_dataset <= Version("0.2.0"):
+        if SampleField.BITMASKS in samples.columns and samples[SampleField.BITMASKS].dtype == pl.List(pl.Struct):
+            struct_schema = samples.schema[SampleField.BITMASKS].inner
+            struct_names = [f.name for f in struct_schema.fields]
+            if "rleString" in struct_names:
+                struct_names[struct_names.index("rleString")] = "rle_string"
+                samples = samples.with_columns(
+                    pl.col(SampleField.BITMASKS).list.eval(pl.element().struct.rename_fields(struct_names))
+                )
+    return samples, dataset_info
+
+
 def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
     dataset.check_dataset()
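The bitmask correction above renames the struct field rleString to rle_string inside a list-of-structs column without touching the values. A standalone polars sketch of the same pattern (toy schema, not the real one):

    import polars as pl

    df = pl.DataFrame({"bitmasks": [[{"rleString": "abc", "area": 10}], [{"rleString": "def", "area": 20}]]})

    # Rename the inner struct fields of every list element.
    df = df.with_columns(
        pl.col("bitmasks").list.eval(pl.element().struct.rename_fields(["rle_string", "area"]))
    )
    print(df.schema)  # the inner struct field is now named 'rle_string'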

@@ -728,7 +757,8 @@ def get_or_create_dataset_path_from_recipe(
     if force_redownload:
         shutil.rmtree(path_dataset, ignore_errors=True)
 
-
+    dataset_metadata_files = DatasetMetadataFilePaths.from_path(path_dataset)
+    if dataset_metadata_files.exists(raise_error=False):
         return path_dataset
 
     path_dataset.mkdir(parents=True, exist_ok=True)

@@ -741,37 +771,101 @@
     return path_dataset
 
 
-def
-
+def available_dataset_versions_from_name(dataset_name: str) -> Dict[Version, "DatasetMetadataFilePaths"]:
+    credentials: ResourceCredentials = get_read_credentials_by_name(dataset_name=dataset_name)
+    return available_dataset_versions(credentials=credentials)
 
-    ## Backwards compatibility fixes for older dataset versions
-    if format_version_of_dataset < Version("0.2.0"):
-        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
 
-
-
+def available_dataset_versions(
+    credentials: ResourceCredentials,
+) -> Dict[Version, "DatasetMetadataFilePaths"]:
+    envs = credentials.aws_credentials()
+    bucket_prefix_sample_versions = f"{credentials.s3_uri()}/versions"
+    all_s3_annotation_files = s5cmd_utils.list_bucket(bucket_prefix=bucket_prefix_sample_versions, append_envs=envs)
+    available_versions = DatasetMetadataFilePaths.available_versions_from_files_list(all_s3_annotation_files)
+    return available_versions
 
-    if SampleField.SAMPLE_INDEX not in samples.columns:
-        samples = table_transformations.add_sample_index(samples)
 
-
-
-
-
+def select_version_from_available_versions(
+    available_versions: Dict[Version, "DatasetMetadataFilePaths"],
+    version: Optional[str],
+) -> "DatasetMetadataFilePaths":
+    if len(available_versions) == 0:
+        raise ValueError("No versions were found in the dataset.")
 
-
-
+    if version is None:
+        str_versions = [str(v) for v in available_versions]
+        raise ValueError(f"Version must be specified. Available versions: {str_versions}")
+    elif version == "latest":
+        version_casted = max(available_versions)
+        user_logger.info(f"'latest' version '{version_casted}' has been selected")
+    else:
+        version_casted = version_from_string(version)
 
-
-
+    if version_casted not in available_versions:
+        raise ValueError(f"Selected version '{version}' not found in available versions: {available_versions}")
 
-
-
-
-
-
-
-
-
-
-
+    return available_versions[version_casted]
+
+
+def download_meta_dataset_files_from_version(
+    resource_credentials: ResourceCredentials, version: Optional[str], path_dataset: Path
+) -> list[str]:
+    envs = resource_credentials.aws_credentials()
+    available_versions = available_dataset_versions(credentials=resource_credentials)
+    metadata_files = select_version_from_available_versions(available_versions=available_versions, version=version)
+
+    s3_files = metadata_files.as_list()
+    path_dataset.mkdir(parents=True, exist_ok=True)
+    local_paths = [(path_dataset / filename.split("/")[-1]).as_posix() for filename in s3_files]
+    s5cmd_utils.fast_copy_files(
+        src_paths=s3_files,
+        dst_paths=local_paths,
+        append_envs=envs,
+        description="Downloading meta dataset files",
+    )
+
+    return local_paths
+
+
+def download_or_get_dataset_path(
+    dataset_name: str,
+    version: Optional[str],
+    cfg: Optional[Config] = None,
+    path_datasets_folder: Optional[str] = None,
+    force_redownload: bool = False,
+    download_files: bool = True,
+) -> Path:
+    """Download or get the path of the dataset."""
+
+    path_datasets = path_datasets_folder or utils.PATH_DATASETS
+    path_dataset = Path(path_datasets) / dataset_name
+    if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+        raise ValueError(
+            f"Invalid version string: {version}. Should be a valid version (e.g. '0.1.0'), 'latest' or None."
+        )
+
+    # Only valid versions (e.g. '0.1.0', '1.0.0') can use local cache. Using either "latest"/None will always redownload
+    if is_valid_version_string(version, allow_none=False, allow_latest=False):
+        dataset_metadata_files = DatasetMetadataFilePaths.from_path(path_dataset)
+        dataset_exists = dataset_metadata_files.exists(version=version, raise_error=False)
+        if dataset_exists and not force_redownload:
+            user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
+            return path_dataset
+
+    cfg = cfg or Config()
+    resource_credentials = get_read_credentials_by_name(dataset_name=dataset_name, cfg=cfg)
+    if resource_credentials is None:
+        raise ValueError(f"Failed to get read credentials for dataset '{dataset_name}' from the platform.")
+
+    download_meta_dataset_files_from_version(
+        resource_credentials=resource_credentials, version=version, path_dataset=path_dataset
+    )
+
+    if not download_files:
+        return path_dataset
+
+    dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
+    dataset = dataset.download_files_aws(path_dataset, aws_credentials=resource_credentials, force_redownload=True)
+    dataset.write_annotations(path_folder=path_dataset)  # Overwrite annotations as files have been re-downloaded
+    return path_dataset