hafnia-0.5.0-py3-none-any.whl → hafnia-0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,12 +3,70 @@ import math
 import random
 import shutil
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import xxhash
+from packaging.version import InvalidVersion, Version
 from PIL import Image
 
+from hafnia.log import user_logger
+
+
+def is_valid_version_string(version: Optional[str], allow_none: bool = False, allow_latest: bool = False) -> bool:
+    if allow_none and version is None:
+        return True
+    if allow_latest and version == "latest":
+        return True
+    return version_from_string(version, raise_error=False) is not None
+
+
+def version_from_string(version: Optional[str], raise_error: bool = True) -> Optional[Version]:
+    if version is None:
+        if raise_error:
+            raise ValueError("Version is 'None'. A valid version string is required e.g '1.0.0'")
+        return None
+
+    try:
+        version_casted = Version(version)
+    except (InvalidVersion, TypeError) as e:
+        if raise_error:
+            raise ValueError(f"Invalid version string/type: {version}") from e
+        return None
+
+    # Check if version is semantic versioning (MAJOR.MINOR.PATCH)
+    if len(version_casted.release) < 3:
+        if raise_error:
+            raise ValueError(f"Version string '{version}' is not semantic versioning (MAJOR.MINOR.PATCH)")
+        return None
+    return version_casted
+
+
+def dataset_name_and_version_from_string(
+    string: str,
+    resolve_missing_version: bool = True,
+) -> Tuple[str, Optional[str]]:
+    if not isinstance(string, str):
+        raise TypeError(f"'{type(string)}' for '{string}' is an unsupported type. Expected 'str' e.g 'mnist:1.0.0'")
+
+    parts = string.split(":")
+    if len(parts) == 1:
+        dataset_name = parts[0]
+        if resolve_missing_version:
+            version = "latest"  # Default to 'latest' if version is missing. This will be resolved to a specific version later.
+            user_logger.info(f"Version is missing in dataset name: {string}. Defaulting to version='latest'.")
+        else:
+            raise ValueError(f"Version is missing in dataset name: {string}. Use 'name:version'")
+    elif len(parts) == 2:
+        dataset_name, version = parts
+    else:
+        raise ValueError(f"Invalid dataset name format: {string}. Use 'name' or 'name:version' ")
+
+    if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+        raise ValueError(f"Invalid version string: {version}. Use semantic versioning e.g. '1.0.0' or 'latest'")
+
+    return dataset_name, version
+
 
 def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
     samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
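The helpers added above are small, self-contained functions, so their intended behavior can be read directly off the diff. A minimal usage sketch (the import path matches the `from hafnia.dataset.dataset_helpers import ...` lines in the hunks further down; the dataset name is just a placeholder):

from hafnia.dataset.dataset_helpers import (
    dataset_name_and_version_from_string,
    is_valid_version_string,
    version_from_string,
)

# Only full MAJOR.MINOR.PATCH strings count as valid; "1.0" is rejected.
assert is_valid_version_string("1.0.0")
assert not is_valid_version_string("1.0")
assert is_valid_version_string("latest", allow_latest=True)

# "name:version" strings are split into their parts; with the default
# resolve_missing_version=True a bare name falls back to "latest".
assert dataset_name_and_version_from_string("mnist:1.0.0") == ("mnist", "1.0.0")
assert dataset_name_and_version_from_string("mnist") == ("mnist", "latest")

# Parsing returns a packaging.version.Version (or None / a ValueError for bad input).
print(version_from_string("1.2.3"))  # 1.2.3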
@@ -1,9 +1,5 @@
 from enum import Enum
-from typing import Dict, List, Optional
-
-import boto3
-from botocore.exceptions import UnauthorizedSSOTokenError
-from pydantic import BaseModel, field_validator
+from typing import List
 
 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
@@ -22,7 +18,6 @@ class DeploymentStage(Enum):
     PRODUCTION = "production"
 
 
-ARN_PREFIX = "arn:aws:s3:::"
 TAG_IS_SAMPLE = "sample"
 
 OPS_REMOVE_CLASS = "__REMOVE__"
@@ -126,105 +121,3 @@ class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
-
-
-class AwsCredentials(BaseModel):
-    access_key: str
-    secret_key: str
-    session_token: str
-    region: Optional[str]
-
-    def aws_credentials(self) -> Dict[str, str]:
-        """
-        Returns the AWS credentials as a dictionary.
-        """
-        environment_vars = {
-            "AWS_ACCESS_KEY_ID": self.access_key,
-            "AWS_SECRET_ACCESS_KEY": self.secret_key,
-            "AWS_SESSION_TOKEN": self.session_token,
-        }
-        if self.region:
-            environment_vars["AWS_REGION"] = self.region
-
-        return environment_vars
-
-    @staticmethod
-    def from_session(session: boto3.Session) -> "AwsCredentials":
-        """
-        Creates AwsCredentials from a Boto3 session.
-        """
-        try:
-            frozen_credentials = session.get_credentials().get_frozen_credentials()
-        except UnauthorizedSSOTokenError as e:
-            raise RuntimeError(
-                f"Failed to get AWS credentials from the session for profile '{session.profile_name}'.\n"
-                f"Ensure the profile exists in your AWS config in '~/.aws/config' and that you are logged in via AWS SSO.\n"
-                f"\tUse 'aws sso login --profile {session.profile_name}' to log in."
-            ) from e
-        return AwsCredentials(
-            access_key=frozen_credentials.access_key,
-            secret_key=frozen_credentials.secret_key,
-            session_token=frozen_credentials.token,
-            region=session.region_name,
-        )
-
-    def to_resource_credentials(self, bucket_name: str) -> "ResourceCredentials":
-        """
-        Converts AwsCredentials to ResourceCredentials by adding the S3 ARN.
-        """
-        payload = self.model_dump()
-        payload["s3_arn"] = f"{ARN_PREFIX}{bucket_name}"
-        return ResourceCredentials(**payload)
-
-
-class ResourceCredentials(AwsCredentials):
-    s3_arn: str
-
-    @staticmethod
-    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
-        """
-        The endpoint returns a payload with a key called 's3_path', but it
-        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
-        """
-        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
-            payload["s3_arn"] = payload.pop("s3_path")
-
-        if "region" not in payload:
-            payload["region"] = "eu-west-1"
-        return ResourceCredentials(**payload)
-
-    @field_validator("s3_arn")
-    @classmethod
-    def validate_s3_arn(cls, value: str) -> str:
-        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
-        if not value.startswith("arn:aws:s3:::"):
-            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
-        return value
-
-    def s3_path(self) -> str:
-        """
-        Extracts the S3 path from the ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
-        """
-        return self.s3_arn[len(ARN_PREFIX) :]
-
-    def s3_uri(self) -> str:
-        """
-        Converts the S3 ARN to a URI format.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
-        """
-        return f"s3://{self.s3_path()}"
-
-    def bucket_name(self) -> str:
-        """
-        Extracts the bucket name from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
-        """
-        return self.s3_path().split("/")[0]
-
-    def object_key(self) -> str:
-        """
-        Extracts the object key from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
-        """
-        return "/".join(self.s3_path().split("/")[1:])
@@ -11,14 +11,19 @@ from pydantic import (
 )
 
 from hafnia import utils
+from hafnia.dataset.dataset_helpers import dataset_name_and_version_from_string
 from hafnia.dataset.dataset_recipe import recipe_transforms
 from hafnia.dataset.dataset_recipe.recipe_types import (
     RecipeCreation,
     RecipeTransform,
     Serializable,
 )
-from hafnia.dataset.hafnia_dataset import HafniaDataset
+from hafnia.dataset.hafnia_dataset import (
+    HafniaDataset,
+    available_dataset_versions_from_name,
+)
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
 
 
 class DatasetRecipe(Serializable):
@@ -41,8 +46,31 @@ class DatasetRecipe(Serializable):
 
     ### Creation Methods (using the 'from_X' )###
     @staticmethod
-    def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> DatasetRecipe:
-        creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
+    def from_name(
+        name: str,
+        version: Optional[str] = None,
+        force_redownload: bool = False,
+        download_files: bool = True,
+    ) -> DatasetRecipe:
+        if version == "latest":
+            user_logger.info(
+                f"The dataset '{name}' in a dataset recipe uses 'latest' as version. For dataset recipes the "
+                "version is pinned to a specific version. Consider specifying a specific version to ensure "
+                "reproducibility of your experiments. "
+            )
+            available_versions = available_dataset_versions_from_name(name)
+            version = str(max(available_versions))
+        if version is None:
+            available_versions = available_dataset_versions_from_name(name)
+            str_versions = ", ".join([str(v) for v in available_versions])
+            raise ValueError(
+                f"Version must be specified when creating a DatasetRecipe from name. "
+                f"Available versions are: {str_versions}"
+            )
+
+        creation = FromName(
+            name=name, version=version, force_redownload=force_redownload, download_files=download_files
+        )
         return DatasetRecipe(creation=creation)
 
     @staticmethod
@@ -125,6 +153,21 @@ class DatasetRecipe(Serializable):
         recipe_id = recipe["id"]
         return DatasetRecipe.from_recipe_id(recipe_id)
 
+    @staticmethod
+    def from_name_and_version_string(string: str, resolve_missing_version: bool = False) -> "DatasetRecipe":
+        """
+        Validates and converts a dataset name and version string (name:version) to a DatasetRecipe.from_name recipe.
+        If version is missing and 'resolve_missing_version' is True, it will default to 'latest'.
+        If resolve_missing_version is False, it will raise an error if version is missing.
+        """
+
+        dataset_name, version = dataset_name_and_version_from_string(
+            string=string,
+            resolve_missing_version=resolve_missing_version,
+        )
+
+        return DatasetRecipe.from_name(name=dataset_name, version=version)
+
     @staticmethod
     def from_implicit_form(recipe: Any) -> DatasetRecipe:
         """
@@ -180,7 +223,7 @@ class DatasetRecipe(Serializable):
             return recipe
 
         if isinstance(recipe, str):  # str-type is convert to DatasetFromName
-            return DatasetRecipe.from_name(name=recipe)
+            return DatasetRecipe.from_name_and_version_string(string=recipe, resolve_missing_version=True)
 
         if isinstance(recipe, Path):  # Path-type is convert to DatasetFromPath
             return DatasetRecipe.from_path(path_folder=recipe)
@@ -409,6 +452,7 @@ class FromPath(RecipeCreation):
 
 class FromName(RecipeCreation):
     name: str
+    version: Optional[str] = None
    force_redownload: bool = False
    download_files: bool = True
 
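Read together, these hunks pin a dataset version into the recipe itself. A sketch of how the new entry points appear to be used (the DatasetRecipe import path is assumed rather than shown in this diff, 'mnist' is a placeholder, and resolving 'latest' or a bare name requires access to the Hafnia platform):

from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe  # module path assumed

# Preferred: pin an explicit version so the recipe stays reproducible.
recipe = DatasetRecipe.from_name(name="mnist", version="1.0.0")

# "name:version" shorthand goes through dataset_name_and_version_from_string().
recipe = DatasetRecipe.from_name_and_version_string("mnist:1.0.0")

# The implicit string form still works; a missing version resolves to the newest
# available version on the platform at recipe-creation time (and logs a hint).
recipe = DatasetRecipe.from_implicit_form("mnist")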
@@ -78,7 +78,7 @@ def caltech_101_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.1.0"
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @article{FeiFei2004LearningGV,
         title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian
@@ -108,7 +108,7 @@ def caltech_256_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.1.0"
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @misc{griffin_2023_5sv1j-ytw97,
         author = {Griffin, Gregory and
@@ -10,14 +10,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import polars as pl
 from packaging.version import Version
 
+from hafnia import utils
 from hafnia.dataset import dataset_helpers
+from hafnia.dataset.dataset_helpers import is_valid_version_string, version_from_string
 from hafnia.dataset.dataset_names import (
-    FILENAME_ANNOTATIONS_JSONL,
-    FILENAME_ANNOTATIONS_PARQUET,
-    FILENAME_DATASET_INFO,
     FILENAME_RECIPE_JSON,
     TAG_IS_SAMPLE,
-    AwsCredentials,
     PrimitiveField,
     SampleField,
     SplitName,
@@ -28,7 +26,7 @@ from hafnia.dataset.format_conversions import (
     format_image_classification_folder,
     format_yolo,
 )
-from hafnia.dataset.hafnia_dataset_types import DatasetInfo, Sample
+from hafnia.dataset.hafnia_dataset_types import DatasetInfo, DatasetMetadataFilePaths, Sample
 from hafnia.dataset.operations import (
     dataset_stats,
     dataset_transformations,
@@ -36,6 +34,9 @@ from hafnia.dataset.operations import (
 )
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.log import user_logger
+from hafnia.platform import s5cmd_utils
+from hafnia.platform.datasets import get_read_credentials_by_name
+from hafnia.platform.s5cmd_utils import AwsCredentials, ResourceCredentials
 from hafnia.utils import progress_bar
 from hafnia_cli.config import Config
 
@@ -89,10 +90,11 @@ class HafniaDataset:
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> "HafniaDataset":
         path_folder = Path(path_folder)
-        HafniaDataset.check_dataset_path(path_folder, raise_error=True)
+        metadata_file_paths = DatasetMetadataFilePaths.from_path(path_folder)
+        metadata_file_paths.exists(raise_error=True)
 
-        dataset_info = DatasetInfo.from_json_file(path_folder / FILENAME_DATASET_INFO)
-        samples = table_transformations.read_samples_from_path(path_folder)
+        dataset_info = DatasetInfo.from_json_file(Path(metadata_file_paths.dataset_info))
+        samples = metadata_file_paths.read_samples()
         samples, dataset_info = _dataset_corrections(samples, dataset_info)
 
         # Convert from relative paths to absolute paths
@@ -103,14 +105,24 @@
         return HafniaDataset(samples=samples, info=dataset_info)
 
     @staticmethod
-    def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> "HafniaDataset":
+    def from_name(
+        name: str,
+        version: Optional[str] = None,
+        force_redownload: bool = False,
+        download_files: bool = True,
+    ) -> "HafniaDataset":
         """
         Load a dataset by its name. The dataset must be registered in the Hafnia platform.
         """
-        from hafnia.platform.datasets import download_or_get_dataset_path
-
+        if ":" in name:
+            name, version = dataset_helpers.dataset_name_and_version_from_string(name)
+            raise ValueError(
+                "The 'from_name' does not support the 'name:version' format. Please provide the version separately.\n"
+                f"E.g., HafniaDataset.from_name(name='{name}', version='{version}')"
+            )
         dataset_path = download_or_get_dataset_path(
             dataset_name=name,
+            version=version,
            force_redownload=force_redownload,
            download_files=download_files,
        )
@@ -523,30 +535,6 @@ class HafniaDataset:
         table = dataset.samples if isinstance(dataset, HafniaDataset) else dataset
         return table_transformations.has_primitive(table, PrimitiveType)
 
-    @staticmethod
-    def check_dataset_path(path_dataset: Path, raise_error: bool = True) -> bool:
-        """
-        Checks if the dataset path exists and contains the required files.
-        Returns True if the dataset is valid, otherwise raises an error or returns False.
-        """
-        if not path_dataset.exists():
-            if raise_error:
-                raise FileNotFoundError(f"Dataset path {path_dataset} does not exist.")
-            return False
-
-        required_files = [
-            FILENAME_DATASET_INFO,
-            FILENAME_ANNOTATIONS_JSONL,
-            FILENAME_ANNOTATIONS_PARQUET,
-        ]
-        for filename in required_files:
-            if not (path_dataset / filename).exists():
-                if raise_error:
-                    raise FileNotFoundError(f"Required file {filename} not found in {path_dataset}.")
-                return False
-
-        return True
-
     def copy(self) -> "HafniaDataset":
         return HafniaDataset(info=self.info.model_copy(deep=True), samples=self.samples.clone())
 
@@ -584,16 +572,18 @@ class HafniaDataset:
         """
         Writes only the annotations files (JSONL and Parquet) to the specified folder.
         """
+
         user_logger.info(f"Writing dataset annotations to {path_folder}...")
-        path_folder = path_folder.absolute()
-        if not path_folder.exists():
-            path_folder.mkdir(parents=True)
-        dataset.info.write_json(path_folder / FILENAME_DATASET_INFO)
+        metadata_file_paths = DatasetMetadataFilePaths.from_path(path_folder)
+        path_dataset_info = Path(metadata_file_paths.dataset_info)
+        path_dataset_info.parent.mkdir(parents=True, exist_ok=True)
+        dataset.info.write_json(path_dataset_info)
 
         samples = dataset.samples
         if drop_null_cols:  # Drops all unused/Null columns
             samples = samples.drop(pl.selectors.by_dtype(pl.Null))
 
+        path_folder = path_folder.absolute()
         # Store only relative paths in the annotations files
         if SampleField.FILE_PATH in samples.columns:  # We drop column for remote datasets
             absolute_paths = samples[SampleField.FILE_PATH].to_list()
@@ -601,8 +591,11 @@ class HafniaDataset:
             samples = samples.with_columns(pl.Series(relative_paths).alias(SampleField.FILE_PATH))
         else:
            samples = samples.with_columns(pl.lit("").alias(SampleField.FILE_PATH))
-        samples.write_ndjson(path_folder / FILENAME_ANNOTATIONS_JSONL)  # Json for readability
-        samples.write_parquet(path_folder / FILENAME_ANNOTATIONS_PARQUET)  # Parquet for speed
+
+        if metadata_file_paths.annotations_jsonl:
+            samples.write_ndjson(Path(metadata_file_paths.annotations_jsonl))  # Json for readability
+        if metadata_file_paths.annotations_parquet:
+            samples.write_parquet(Path(metadata_file_paths.annotations_parquet))  # Parquet for speed
 
     def delete_on_platform(dataset: HafniaDataset, interactive: bool = True) -> None:
         """
@@ -707,6 +700,42 @@ class HafniaDataset:
         return True
 
 
+def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
+    format_version_of_dataset = Version(dataset_info.format_version)
+
+    ## Backwards compatibility fixes for older dataset versions
+    if format_version_of_dataset < Version("0.2.0"):
+        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
+
+    if "file_name" in samples.columns:
+        samples = samples.rename({"file_name": SampleField.FILE_PATH})
+
+    if SampleField.SAMPLE_INDEX not in samples.columns:
+        samples = table_transformations.add_sample_index(samples)
+
+    # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
+    if SampleField.TAGS not in samples.columns:
+        tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
+        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+
+    if SampleField.STORAGE_FORMAT not in samples.columns:
+        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+
+    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
+        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
+
+    if format_version_of_dataset <= Version("0.2.0"):
+        if SampleField.BITMASKS in samples.columns and samples[SampleField.BITMASKS].dtype == pl.List(pl.Struct):
+            struct_schema = samples.schema[SampleField.BITMASKS].inner
+            struct_names = [f.name for f in struct_schema.fields]
+            if "rleString" in struct_names:
+                struct_names[struct_names.index("rleString")] = "rle_string"
+                samples = samples.with_columns(
+                    pl.col(SampleField.BITMASKS).list.eval(pl.element().struct.rename_fields(struct_names))
+                )
+    return samples, dataset_info
+
+
 def check_hafnia_dataset_from_path(path_dataset: Path) -> None:
     dataset = HafniaDataset.from_path(path_dataset, check_for_images=True)
     dataset.check_dataset()
@@ -728,7 +757,8 @@ def get_or_create_dataset_path_from_recipe(
     if force_redownload:
         shutil.rmtree(path_dataset, ignore_errors=True)
 
-    if HafniaDataset.check_dataset_path(path_dataset, raise_error=False):
+    dataset_metadata_files = DatasetMetadataFilePaths.from_path(path_dataset)
+    if dataset_metadata_files.exists(raise_error=False):
         return path_dataset
 
     path_dataset.mkdir(parents=True, exist_ok=True)
@@ -741,37 +771,101 @@
     return path_dataset
 
 
-def _dataset_corrections(samples: pl.DataFrame, dataset_info: DatasetInfo) -> Tuple[pl.DataFrame, DatasetInfo]:
-    format_version_of_dataset = Version(dataset_info.format_version)
+def available_dataset_versions_from_name(dataset_name: str) -> Dict[Version, "DatasetMetadataFilePaths"]:
+    credentials: ResourceCredentials = get_read_credentials_by_name(dataset_name=dataset_name)
+    return available_dataset_versions(credentials=credentials)
 
-    ## Backwards compatibility fixes for older dataset versions
-    if format_version_of_dataset < Version("0.2.0"):
-        samples = table_transformations.add_dataset_name_if_missing(samples, dataset_info.dataset_name)
 
-    if "file_name" in samples.columns:
-        samples = samples.rename({"file_name": SampleField.FILE_PATH})
+def available_dataset_versions(
+    credentials: ResourceCredentials,
+) -> Dict[Version, "DatasetMetadataFilePaths"]:
+    envs = credentials.aws_credentials()
+    bucket_prefix_sample_versions = f"{credentials.s3_uri()}/versions"
+    all_s3_annotation_files = s5cmd_utils.list_bucket(bucket_prefix=bucket_prefix_sample_versions, append_envs=envs)
+    available_versions = DatasetMetadataFilePaths.available_versions_from_files_list(all_s3_annotation_files)
+    return available_versions
 
-    if SampleField.SAMPLE_INDEX not in samples.columns:
-        samples = table_transformations.add_sample_index(samples)
 
-    # Backwards compatibility: If tags-column doesn't exist, create it with empty lists
-    if SampleField.TAGS not in samples.columns:
-        tags_column: List[List[str]] = [[] for _ in range(len(samples))]  # type: ignore[annotation-unchecked]
-        samples = samples.with_columns(pl.Series(tags_column, dtype=pl.List(pl.String)).alias(SampleField.TAGS))
+def select_version_from_available_versions(
+    available_versions: Dict[Version, "DatasetMetadataFilePaths"],
+    version: Optional[str],
+) -> "DatasetMetadataFilePaths":
+    if len(available_versions) == 0:
+        raise ValueError("No versions were found in the dataset.")
 
-    if SampleField.STORAGE_FORMAT not in samples.columns:
-        samples = samples.with_columns(pl.lit(StorageFormat.IMAGE).alias(SampleField.STORAGE_FORMAT))
+    if version is None:
+        str_versions = [str(v) for v in available_versions]
+        raise ValueError(f"Version must be specified. Available versions: {str_versions}")
+    elif version == "latest":
+        version_casted = max(available_versions)
+        user_logger.info(f"'latest' version '{version_casted}' has been selected")
+    else:
+        version_casted = version_from_string(version)
 
-    if SampleField.SAMPLE_INDEX in samples.columns and samples[SampleField.SAMPLE_INDEX].dtype != pl.UInt64:
-        samples = samples.cast({SampleField.SAMPLE_INDEX: pl.UInt64})
+    if version_casted not in available_versions:
+        raise ValueError(f"Selected version '{version}' not found in available versions: {available_versions}")
 
-    if format_version_of_dataset <= Version("0.2.0"):
-        if SampleField.BITMASKS in samples.columns and samples[SampleField.BITMASKS].dtype == pl.List(pl.Struct):
-            struct_schema = samples.schema[SampleField.BITMASKS].inner
-            struct_names = [f.name for f in struct_schema.fields]
-            if "rleString" in struct_names:
-                struct_names[struct_names.index("rleString")] = "rle_string"
-                samples = samples.with_columns(
-                    pl.col(SampleField.BITMASKS).list.eval(pl.element().struct.rename_fields(struct_names))
-                )
-    return samples, dataset_info
+    return available_versions[version_casted]
+
+
+def download_meta_dataset_files_from_version(
+    resource_credentials: ResourceCredentials, version: Optional[str], path_dataset: Path
+) -> list[str]:
+    envs = resource_credentials.aws_credentials()
+    available_versions = available_dataset_versions(credentials=resource_credentials)
+    metadata_files = select_version_from_available_versions(available_versions=available_versions, version=version)
+
+    s3_files = metadata_files.as_list()
+    path_dataset.mkdir(parents=True, exist_ok=True)
+    local_paths = [(path_dataset / filename.split("/")[-1]).as_posix() for filename in s3_files]
+    s5cmd_utils.fast_copy_files(
+        src_paths=s3_files,
+        dst_paths=local_paths,
+        append_envs=envs,
+        description="Downloading meta dataset files",
+    )
+
+    return local_paths
+
+
+def download_or_get_dataset_path(
+    dataset_name: str,
+    version: Optional[str],
+    cfg: Optional[Config] = None,
+    path_datasets_folder: Optional[str] = None,
+    force_redownload: bool = False,
+    download_files: bool = True,
+) -> Path:
+    """Download or get the path of the dataset."""
+
+    path_datasets = path_datasets_folder or utils.PATH_DATASETS
+    path_dataset = Path(path_datasets) / dataset_name
+    if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+        raise ValueError(
+            f"Invalid version string: {version}. Should be a valid version (e.g. '0.1.0'), 'latest' or None."
+        )
+
+    # Only valid versions (e.g. '0.1.0', '1.0.0') can use local cache. Using either "latest"/None will always redownload
+    if is_valid_version_string(version, allow_none=False, allow_latest=False):
+        dataset_metadata_files = DatasetMetadataFilePaths.from_path(path_dataset)
+        dataset_exists = dataset_metadata_files.exists(version=version, raise_error=False)
+        if dataset_exists and not force_redownload:
+            user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
+            return path_dataset
+
+    cfg = cfg or Config()
+    resource_credentials = get_read_credentials_by_name(dataset_name=dataset_name, cfg=cfg)
+    if resource_credentials is None:
+        raise ValueError(f"Failed to get read credentials for dataset '{dataset_name}' from the platform.")
+
+    download_meta_dataset_files_from_version(
+        resource_credentials=resource_credentials, version=version, path_dataset=path_dataset
+    )
+
+    if not download_files:
+        return path_dataset
+
+    dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
+    dataset = dataset.download_files_aws(path_dataset, aws_credentials=resource_credentials, force_redownload=True)
+    dataset.write_annotations(path_folder=path_dataset)  # Overwrite annotations as files have been re-downloaded
+    return path_dataset
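End to end, the new download path is driven by the version argument. A hedged usage sketch of the loading side (the dataset name 'mnist' is a placeholder and the calls need valid Hafnia platform credentials):

from hafnia.dataset.hafnia_dataset import (
    HafniaDataset,
    available_dataset_versions_from_name,
)

# Explicit versions can be served from the local cache; "latest" or None always
# re-resolves against the platform and re-downloads the metadata files.
dataset = HafniaDataset.from_name(name="mnist", version="1.0.0")

# The "name:version" shorthand is rejected here; from_name() raises a ValueError
# telling the caller to pass name and version separately.

# Inspect which versions the platform exposes for a dataset (keys are
# packaging.version.Version objects mapped to their metadata file paths).
versions = available_dataset_versions_from_name("mnist")
print(sorted(versions))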