hafnia 0.4.3__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/dataset/dataset_details_uploader.py +41 -54
- hafnia/dataset/dataset_helpers.py +60 -16
- hafnia/dataset/dataset_names.py +1 -94
- hafnia/dataset/dataset_recipe/dataset_recipe.py +48 -4
- hafnia/dataset/format_conversions/torchvision_datasets.py +8 -5
- hafnia/dataset/hafnia_dataset.py +261 -92
- hafnia/dataset/hafnia_dataset_types.py +145 -19
- hafnia/dataset/operations/dataset_s3_storage.py +216 -0
- hafnia/dataset/operations/table_transformations.py +2 -19
- hafnia/http.py +2 -1
- hafnia/platform/datasets.py +144 -153
- hafnia/platform/download.py +1 -1
- hafnia/platform/s5cmd_utils.py +266 -0
- hafnia/utils.py +4 -0
- {hafnia-0.4.3.dist-info → hafnia-0.5.1.dist-info}/METADATA +3 -3
- {hafnia-0.4.3.dist-info → hafnia-0.5.1.dist-info}/RECORD +22 -20
- {hafnia-0.4.3.dist-info → hafnia-0.5.1.dist-info}/WHEEL +1 -1
- hafnia_cli/dataset_cmds.py +36 -12
- hafnia_cli/profile_cmds.py +0 -1
- hafnia_cli/runc_cmds.py +7 -2
- {hafnia-0.4.3.dist-info → hafnia-0.5.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.4.3.dist-info → hafnia-0.5.1.dist-info}/licenses/LICENSE +0 -0

hafnia/dataset/dataset_details_uploader.py
CHANGED

@@ -4,7 +4,7 @@ import base64
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional,
+from typing import Any, Dict, List, Optional, Type, Union

 import boto3
 import polars as pl
@@ -13,7 +13,6 @@ from pydantic import BaseModel, ConfigDict, field_validator

 from hafnia.dataset.dataset_names import (
     DatasetVariant,
-    DeploymentStage,
     PrimitiveField,
     SampleField,
     SplitName,
@@ -29,26 +28,21 @@ from hafnia.dataset.primitives import (
     Segmentation,
 )
 from hafnia.dataset.primitives.primitive import Primitive
-from hafnia.
-from hafnia.
-from hafnia.platform.datasets import get_dataset_id
+from hafnia.platform.datasets import upload_dataset_details
+from hafnia.utils import get_path_dataset_gallery_images
 from hafnia_cli.config import Config


-def generate_bucket_name(dataset_name: str, deployment_stage: DeploymentStage) -> str:
-    # TODO: When moving to versioning we do NOT need 'staging' and 'production' specific buckets
-    # and the new name convention should be: f"hafnia-dataset-{dataset_name}"
-    return f"mdi-{deployment_stage.value}-{dataset_name}"
-
-
 class DatasetDetails(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
     model_config = ConfigDict(use_enum_values=True)  # To parse Enum values as strings
     name: str
+    title: Optional[str] = None
+    overview: Optional[str] = None
     data_captured_start: Optional[datetime] = None
     data_captured_end: Optional[datetime] = None
     data_received_start: Optional[datetime] = None
     data_received_end: Optional[datetime] = None
-
+    dataset_updated_at: Optional[datetime] = None
     license_citation: Optional[str] = None
     version: Optional[str] = None
     s3_bucket_name: Optional[str] = None
@@ -281,26 +275,32 @@ def get_folder_size(path: Path) -> int:
     return sum([path.stat().st_size for path in path.rglob("*")])


-def
-
-
-
-
-
-
-
-
-
-
-
+def upload_dataset_details_to_platform(
+    dataset: HafniaDataset,
+    path_gallery_images: Optional[Path] = None,
+    gallery_image_names: Optional[List[str]] = None,
+    distribution_task_names: Optional[List[str]] = None,
+    update_platform: bool = True,
+    cfg: Optional[Config] = None,
+) -> dict:
+    cfg = cfg or Config()
+    dataset_details = dataset_details_from_hafnia_dataset(
+        dataset=dataset,
+        path_gallery_images=path_gallery_images,
+        gallery_image_names=gallery_image_names,
+        distribution_task_names=distribution_task_names,
+    )

-
-
+    if update_platform:
+        dataset_details_exclude_none = dataset_details.model_dump(exclude_none=True, mode="json")
+        upload_dataset_details(
+            cfg=cfg,
+            data=dataset_details_exclude_none,
+            dataset_name=dataset_details.name,
+        )

-
-
-    return response  # type: ignore[return-value]
+    dataset_details_dict = dataset_details.model_dump(exclude_none=False, mode="json")
+    return dataset_details_dict


 def get_resolutions(dataset: HafniaDataset, max_resolutions_selected: int = 8) -> List[DbResolution]:
@@ -360,9 +360,6 @@ def s3_based_fields(bucket_name: str, variant_type: DatasetVariant, session: bot

 def dataset_details_from_hafnia_dataset(
     dataset: HafniaDataset,
-    deployment_stage: DeploymentStage,
-    path_sample: Optional[Path],
-    path_hidden: Optional[Path],
     path_gallery_images: Optional[Path] = None,
     gallery_image_names: Optional[List[str]] = None,
     distribution_task_names: Optional[List[str]] = None,
@@ -371,33 +368,24 @@ def dataset_details_from_hafnia_dataset(
     dataset_reports = []
     dataset_meta_info = dataset.info.meta or {}

-    path_and_variant
-    if path_sample is not None:
-        path_and_variant.append((path_sample, DatasetVariant.SAMPLE))
-
-    if path_hidden is not None:
-        path_and_variant.append((path_hidden, DatasetVariant.HIDDEN))
-
-    if len(path_and_variant) == 0:
-        raise ValueError("At least one path must be provided for sample or hidden dataset.")
-
+    path_and_variant = [DatasetVariant.SAMPLE, DatasetVariant.HIDDEN]
     gallery_images = create_gallery_images(
         dataset=dataset,
         path_gallery_images=path_gallery_images,
         gallery_image_names=gallery_image_names,
     )

-    for
+    for variant_type in path_and_variant:
         if variant_type == DatasetVariant.SAMPLE:
             dataset_variant = dataset.create_sample_dataset()
         else:
             dataset_variant = dataset

-
+        files_paths = dataset_variant.samples[SampleField.FILE_PATH].to_list()
+        size_bytes = sum([Path(file_path).stat().st_size for file_path in files_paths])
         dataset_variants.append(
             DbDatasetVariant(
                 variant_type=VARIANT_TYPE_MAPPING[variant_type],  # type: ignore[index]
-                # upload_date: Optional[datetime] = None
                 size_bytes=size_bytes,
                 data_type=DataTypeChoices.images,
                 number_of_data_items=len(dataset_variant),
@@ -405,7 +393,6 @@ def dataset_details_from_hafnia_dataset(
                 duration=dataset_meta_info.get("duration", None),
                 duration_average=dataset_meta_info.get("duration_average", None),
                 frame_rate=dataset_meta_info.get("frame_rate", None),
-                # bit_rate: Optional[float] = None
                 n_cameras=dataset_meta_info.get("n_cameras", None),
             )
         )
@@ -435,19 +422,19 @@ def dataset_details_from_hafnia_dataset(
         object_reports = sorted(object_reports, key=lambda x: x.obj.name)  # Sort object reports by name
         report.annotated_object_reports = object_reports

-
-
+        if report.distribution_values is None:
+            report.distribution_values = []

-
+        dataset_reports.append(report)
     dataset_name = dataset.info.dataset_name
-    bucket_sample = generate_bucket_name(dataset_name, deployment_stage=deployment_stage)
     dataset_info = DatasetDetails(
         name=dataset_name,
+        title=dataset.info.dataset_title,
+        overview=dataset.info.description,
         version=dataset.info.version,
-        s3_bucket_name=bucket_sample,
         dataset_variants=dataset_variants,
         split_annotations_reports=dataset_reports,
-
+        dataset_updated_at=dataset.info.updated_at,
         dataset_format_version=dataset.info.format_version,
         license_citation=dataset.info.reference_bibtex,
         data_captured_start=dataset_meta_info.get("data_captured_start", None),
@@ -565,7 +552,7 @@ def create_gallery_images(
     gallery_images = None
     if (gallery_image_names is not None) and (len(gallery_image_names) > 0):
        if path_gallery_images is None:
-
+            path_gallery_images = get_path_dataset_gallery_images(dataset.info.dataset_name)
        path_gallery_images.mkdir(parents=True, exist_ok=True)
        COL_IMAGE_NAME = "image_name"
        samples = dataset.samples.with_columns(
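The refactor above drops the deployment-stage bucket naming (`generate_bucket_name`) and collapses detail generation and upload into a single `upload_dataset_details_to_platform` entry point. A minimal usage sketch based on the new signature shown in the hunk; the wrapper function and the use of `update_platform=False` as a dry run are illustrative, not part of the package:

```python
from hafnia.dataset.dataset_details_uploader import upload_dataset_details_to_platform
from hafnia.dataset.hafnia_dataset import HafniaDataset


def preview_dataset_details(dataset: HafniaDataset) -> dict:
    # Build the details payload locally without contacting the platform.
    # With update_platform=True (the default) the payload would also be sent
    # via upload_dataset_details from hafnia.platform.datasets.
    return upload_dataset_details_to_platform(dataset=dataset, update_platform=False)
```

Note that the function now returns the full details dictionary (dumped with `exclude_none=False`) rather than the raw platform response.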
hafnia/dataset/dataset_helpers.py
CHANGED

@@ -3,12 +3,70 @@ import math
 import random
 import shutil
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import xxhash
+from packaging.version import InvalidVersion, Version
 from PIL import Image

+from hafnia.log import user_logger
+
+
+def is_valid_version_string(version: Optional[str], allow_none: bool = False, allow_latest: bool = False) -> bool:
+    if allow_none and version is None:
+        return True
+    if allow_latest and version == "latest":
+        return True
+    return version_from_string(version, raise_error=False) is not None
+
+
+def version_from_string(version: Optional[str], raise_error: bool = True) -> Optional[Version]:
+    if version is None:
+        if raise_error:
+            raise ValueError("Version is 'None'. A valid version string is required e.g '1.0.0'")
+        return None
+
+    try:
+        version_casted = Version(version)
+    except (InvalidVersion, TypeError) as e:
+        if raise_error:
+            raise ValueError(f"Invalid version string/type: {version}") from e
+        return None
+
+    # Check if version is semantic versioning (MAJOR.MINOR.PATCH)
+    if len(version_casted.release) < 3:
+        if raise_error:
+            raise ValueError(f"Version string '{version}' is not semantic versioning (MAJOR.MINOR.PATCH)")
+        return None
+    return version_casted
+
+
+def dataset_name_and_version_from_string(
+    string: str,
+    resolve_missing_version: bool = True,
+) -> Tuple[str, Optional[str]]:
+    if not isinstance(string, str):
+        raise TypeError(f"'{type(string)}' for '{string}' is an unsupported type. Expected 'str' e.g 'mnist:1.0.0'")
+
+    parts = string.split(":")
+    if len(parts) == 1:
+        dataset_name = parts[0]
+        if resolve_missing_version:
+            version = "latest"  # Default to 'latest' if version is missing. This will be resolved to a specific version later.
+            user_logger.info(f"Version is missing in dataset name: {string}. Defaulting to version='latest'.")
+        else:
+            raise ValueError(f"Version is missing in dataset name: {string}. Use 'name:version'")
+    elif len(parts) == 2:
+        dataset_name, version = parts
+    else:
+        raise ValueError(f"Invalid dataset name format: {string}. Use 'name' or 'name:version' ")
+
+    if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+        raise ValueError(f"Invalid version string: {version}. Use semantic versioning e.g. '1.0.0' or 'latest'")
+
+    return dataset_name, version
+

 def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
     samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
@@ -57,20 +115,6 @@ def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_s
 def copy_and_rename_file_to_hash_value(path_source: Path, path_dataset_root: Path) -> Path:
     """
     Copies a file to a dataset root directory with a hash-based name and sub-directory structure.
-
-    E.g. for an "image.png" with hash "dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4", the image will be copied to
-    'path_dataset_root / "data" / "dfe" / "dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4.png"'
-    Notice that the hash is used for both the filename and the subfolder name.
-
-    Placing image/video files into multiple sub-folders (instead of one large folder) is seemingly
-    unnecessary, but it is actually a requirement when the dataset is later downloaded from S3.
-
-    The reason is that AWS has a rate limit of 3500 ops/sec per prefix (sub-folder) in S3 - meaning we can "only"
-    download 3500 files per second from a single folder (prefix) in S3.
-
-    For even a single user, we found that this limit was being reached when files are stored in single folder (prefix)
-    in S3. To support multiple users and concurrent experiments, we are required to separate files into
-    multiple sub-folders (prefixes) in S3 to not hit the rate limit.
     """

     if not path_source.exists():
@@ -86,7 +130,7 @@ def copy_and_rename_file_to_hash_value(path_source: Path, path_dataset_root: Pat


 def relative_path_from_hash(hash: str, suffix: str) -> Path:
-    path_file = Path("data") /
+    path_file = Path("data") / f"{hash}{suffix}"
     return path_file

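The new version helpers in `dataset_helpers.py` enforce semantic versioning (MAJOR.MINOR.PATCH) and parse `name:version` strings. A short sketch exercising them, with illustrative example values:

```python
from hafnia.dataset.dataset_helpers import (
    dataset_name_and_version_from_string,
    is_valid_version_string,
    version_from_string,
)

# "name:version" strings are split and validated.
assert dataset_name_and_version_from_string("mnist:1.0.0") == ("mnist", "1.0.0")

# A missing version defaults to 'latest' (logged via user_logger); with
# resolve_missing_version=False it raises a ValueError instead.
assert dataset_name_and_version_from_string("mnist") == ("mnist", "latest")

# Only full MAJOR.MINOR.PATCH versions are accepted, plus 'latest' when allowed.
assert is_valid_version_string("1.2.3")
assert not is_valid_version_string("1.2")
assert is_valid_version_string("latest", allow_latest=True)

print(version_from_string("2.0.1"))  # prints: 2.0.1 (a packaging.version.Version)
```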
hafnia/dataset/dataset_names.py
CHANGED

@@ -1,8 +1,5 @@
 from enum import Enum
-from typing import
-
-import boto3
-from pydantic import BaseModel, field_validator
+from typing import List

 FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
@@ -124,93 +121,3 @@ class DatasetVariant(Enum):
     DUMP = "dump"
     SAMPLE = "sample"
     HIDDEN = "hidden"
-
-
-class AwsCredentials(BaseModel):
-    access_key: str
-    secret_key: str
-    session_token: str
-    region: Optional[str]
-
-    def aws_credentials(self) -> Dict[str, str]:
-        """
-        Returns the AWS credentials as a dictionary.
-        """
-        environment_vars = {
-            "AWS_ACCESS_KEY_ID": self.access_key,
-            "AWS_SECRET_ACCESS_KEY": self.secret_key,
-            "AWS_SESSION_TOKEN": self.session_token,
-        }
-        if self.region:
-            environment_vars["AWS_REGION"] = self.region
-
-        return environment_vars
-
-    @staticmethod
-    def from_session(session: boto3.Session) -> "AwsCredentials":
-        """
-        Creates AwsCredentials from a Boto3 session.
-        """
-        frozen_credentials = session.get_credentials().get_frozen_credentials()
-        return AwsCredentials(
-            access_key=frozen_credentials.access_key,
-            secret_key=frozen_credentials.secret_key,
-            session_token=frozen_credentials.token,
-            region=session.region_name,
-        )
-
-
-ARN_PREFIX = "arn:aws:s3:::"
-
-
-class ResourceCredentials(AwsCredentials):
-    s3_arn: str
-
-    @staticmethod
-    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
-        """
-        The endpoint returns a payload with a key called 's3_path', but it
-        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
-        """
-        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
-            payload["s3_arn"] = payload.pop("s3_path")
-
-        if "region" not in payload:
-            payload["region"] = "eu-west-1"
-        return ResourceCredentials(**payload)
-
-    @field_validator("s3_arn")
-    @classmethod
-    def validate_s3_arn(cls, value: str) -> str:
-        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
-        if not value.startswith("arn:aws:s3:::"):
-            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
-        return value
-
-    def s3_path(self) -> str:
-        """
-        Extracts the S3 path from the ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
-        """
-        return self.s3_arn[len(ARN_PREFIX) :]
-
-    def s3_uri(self) -> str:
-        """
-        Converts the S3 ARN to a URI format.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
-        """
-        return f"s3://{self.s3_path()}"
-
-    def bucket_name(self) -> str:
-        """
-        Extracts the bucket name from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
-        """
-        return self.s3_path().split("/")[0]
-
-    def object_key(self) -> str:
-        """
-        Extracts the object key from the S3 ARN.
-        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
-        """
-        return "/".join(self.s3_path().split("/")[1:])
hafnia/dataset/dataset_recipe/dataset_recipe.py
CHANGED

@@ -11,14 +11,19 @@ from pydantic import (
 )

 from hafnia import utils
+from hafnia.dataset.dataset_helpers import dataset_name_and_version_from_string
 from hafnia.dataset.dataset_recipe import recipe_transforms
 from hafnia.dataset.dataset_recipe.recipe_types import (
     RecipeCreation,
     RecipeTransform,
     Serializable,
 )
-from hafnia.dataset.hafnia_dataset import
+from hafnia.dataset.hafnia_dataset import (
+    HafniaDataset,
+    available_dataset_versions_from_name,
+)
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger


 class DatasetRecipe(Serializable):
@@ -41,8 +46,31 @@ class DatasetRecipe(Serializable):

     ### Creation Methods (using the 'from_X' )###
     @staticmethod
-    def from_name(
-
+    def from_name(
+        name: str,
+        version: Optional[str] = None,
+        force_redownload: bool = False,
+        download_files: bool = True,
+    ) -> DatasetRecipe:
+        if version == "latest":
+            user_logger.info(
+                f"The dataset '{name}' in a dataset recipe uses 'latest' as version. For dataset recipes the "
+                "version is pinned to a specific version. Consider specifying a specific version to ensure "
+                "reproducibility of your experiments. "
+            )
+            available_versions = available_dataset_versions_from_name(name)
+            version = str(max(available_versions))
+        if version is None:
+            available_versions = available_dataset_versions_from_name(name)
+            str_versions = ", ".join([str(v) for v in available_versions])
+            raise ValueError(
+                f"Version must be specified when creating a DatasetRecipe from name. "
+                f"Available versions are: {str_versions}"
+            )
+
+        creation = FromName(
+            name=name, version=version, force_redownload=force_redownload, download_files=download_files
+        )
         return DatasetRecipe(creation=creation)

     @staticmethod
@@ -125,6 +153,21 @@ class DatasetRecipe(Serializable):
         recipe_id = recipe["id"]
         return DatasetRecipe.from_recipe_id(recipe_id)

+    @staticmethod
+    def from_name_and_version_string(string: str, resolve_missing_version: bool = False) -> "DatasetRecipe":
+        """
+        Validates and converts a dataset name and version string (name:version) to a DatasetRecipe.from_name recipe.
+        If version is missing and 'resolve_missing_version' is True, it will default to 'latest'.
+        If resolve_missing_version is False, it will raise an error if version is missing.
+        """
+
+        dataset_name, version = dataset_name_and_version_from_string(
+            string=string,
+            resolve_missing_version=resolve_missing_version,
+        )
+
+        return DatasetRecipe.from_name(name=dataset_name, version=version)
+
     @staticmethod
     def from_implicit_form(recipe: Any) -> DatasetRecipe:
         """
@@ -180,7 +223,7 @@ class DatasetRecipe(Serializable):
             return recipe

         if isinstance(recipe, str):  # str-type is convert to DatasetFromName
-            return DatasetRecipe.
+            return DatasetRecipe.from_name_and_version_string(string=recipe, resolve_missing_version=True)

         if isinstance(recipe, Path):  # Path-type is convert to DatasetFromPath
             return DatasetRecipe.from_path(path_folder=recipe)
@@ -409,6 +452,7 @@ class FromPath(RecipeCreation):

 class FromName(RecipeCreation):
     name: str
+    version: Optional[str] = None
     force_redownload: bool = False
     download_files: bool = True

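`DatasetRecipe.from_name` now pins a dataset version, and the implicit form accepts `name:version` strings via the new `from_name_and_version_string` helper. A hedged sketch of the resulting API; the dataset name and version are placeholders, and resolving `'latest'` or a missing version queries the platform for available versions:

```python
from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

# Explicitly pinned version (recommended for reproducible experiments).
recipe = DatasetRecipe.from_name(name="mnist", version="1.0.0")

# Equivalent, parsed from a "name:version" string.
recipe = DatasetRecipe.from_name_and_version_string("mnist:1.0.0")

# Implicit form: a bare string defaults to 'latest', which is then pinned
# to the newest available version reported by the platform.
recipe = DatasetRecipe.from_implicit_form("mnist")
```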
hafnia/dataset/format_conversions/torchvision_datasets.py
CHANGED

@@ -40,7 +40,7 @@ def mnist_as_hafnia_dataset(force_redownload=False, n_samples: Optional[int] = N

     dataset_info = DatasetInfo(
         dataset_name="mnist",
-        version="1.
+        version="1.0.0",
         tasks=tasks,
         reference_bibtex=textwrap.dedent("""\
             @article{lecun2010mnist,
@@ -78,7 +78,7 @@ def caltech_101_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @article{FeiFei2004LearningGV,
         title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian
@@ -108,7 +108,7 @@ def caltech_256_as_hafnia_dataset(
         n_samples=n_samples,
         dataset_name=dataset_name,
     )
-    hafnia_dataset.info.version = "1.
+    hafnia_dataset.info.version = "1.0.0"
     hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
         @misc{griffin_2023_5sv1j-ytw97,
         author = {Griffin, Gregory and
@@ -150,7 +150,7 @@ def cifar_as_hafnia_dataset(

     dataset_info = DatasetInfo(
         dataset_name=dataset_name,
-        version="1.
+        version="1.0.0",
         tasks=tasks,
         reference_bibtex=textwrap.dedent("""\
             @@TECHREPORT{Krizhevsky09learningmultiple,
@@ -268,7 +268,10 @@ def _download_and_extract_caltech_dataset(dataset_name: str, force_redownload: b
         path_output_extracted = path_tmp_output / "caltech-101"
         for gzip_file in os.listdir(path_output_extracted):
             if gzip_file.endswith(".gz"):
-                extract_archive(
+                extract_archive(
+                    from_path=os.path.join(path_output_extracted, gzip_file),
+                    to_path=path_output_extracted,
+                )
         path_org = path_output_extracted / "101_ObjectCategories"

     elif dataset_name == "caltech-256":
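The torchvision converters now pin `info.version` to a full semantic version (`"1.0.0"`). A small sketch, assuming the converter returns the `HafniaDataset` it builds and that torchvision can download MNIST locally; `n_samples` appears in the signature shown above:

```python
from hafnia.dataset.format_conversions.torchvision_datasets import mnist_as_hafnia_dataset

# Convert a small MNIST subset for a quick check of the pinned version.
dataset = mnist_as_hafnia_dataset(n_samples=100)
print(dataset.info.dataset_name, dataset.info.version)  # expected: mnist 1.0.0
```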