hafnia 0.1.26__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. cli/__main__.py +2 -2
  2. cli/dataset_cmds.py +60 -0
  3. cli/runc_cmds.py +1 -1
  4. hafnia/data/__init__.py +2 -2
  5. hafnia/data/factory.py +9 -56
  6. hafnia/dataset/dataset_helpers.py +91 -0
  7. hafnia/dataset/dataset_names.py +71 -0
  8. hafnia/dataset/dataset_transformation.py +187 -0
  9. hafnia/dataset/dataset_upload_helper.py +468 -0
  10. hafnia/dataset/hafnia_dataset.py +453 -0
  11. hafnia/dataset/primitives/__init__.py +16 -0
  12. hafnia/dataset/primitives/bbox.py +137 -0
  13. hafnia/dataset/primitives/bitmask.py +182 -0
  14. hafnia/dataset/primitives/classification.py +56 -0
  15. hafnia/dataset/primitives/point.py +25 -0
  16. hafnia/dataset/primitives/polygon.py +100 -0
  17. hafnia/dataset/primitives/primitive.py +44 -0
  18. hafnia/dataset/primitives/segmentation.py +51 -0
  19. hafnia/dataset/primitives/utils.py +51 -0
  20. hafnia/dataset/table_transformations.py +183 -0
  21. hafnia/experiment/hafnia_logger.py +2 -2
  22. hafnia/helper_testing.py +63 -0
  23. hafnia/http.py +5 -3
  24. hafnia/platform/__init__.py +2 -2
  25. hafnia/platform/builder.py +25 -19
  26. hafnia/platform/datasets.py +184 -0
  27. hafnia/platform/download.py +85 -23
  28. hafnia/torch_helpers.py +180 -95
  29. hafnia/utils.py +1 -1
  30. hafnia/visualizations/colors.py +267 -0
  31. hafnia/visualizations/image_visualizations.py +202 -0
  32. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
  33. hafnia-0.2.0.dist-info/RECORD +46 -0
  34. cli/data_cmds.py +0 -53
  35. hafnia-0.1.26.dist-info/RECORD +0 -27
  36. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
  37. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
  38. {hafnia-0.1.26.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/table_transformations.py ADDED
@@ -0,0 +1,183 @@
+ from pathlib import Path
+ from typing import List, Optional, Type
+
+ import polars as pl
+ from tqdm import tqdm
+
+ from hafnia.dataset import table_transformations
+ from hafnia.dataset.dataset_names import (
+     FILENAME_ANNOTATIONS_JSONL,
+     FILENAME_ANNOTATIONS_PARQUET,
+     FieldName,
+ )
+ from hafnia.dataset.primitives import PRIMITIVE_TYPES
+ from hafnia.dataset.primitives.classification import Classification
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.log import user_logger
+
+
+ def create_primitive_table(
+     samples_table: pl.DataFrame, PrimitiveType: Type[Primitive], keep_sample_data: bool = False
+ ) -> Optional[pl.DataFrame]:
+     """
+     Returns a DataFrame with objects of the specified primitive type.
+     """
+     column_name = PrimitiveType.column_name()
+     has_primitive_column = (column_name in samples_table.columns) and (
+         samples_table[column_name].dtype == pl.List(pl.Struct)
+     )
+     if not has_primitive_column:
+         return None
+
+     # Remove frames without objects
+     remove_no_object_frames = samples_table.filter(pl.col(column_name).list.len() > 0)
+
+     if keep_sample_data:
+         # Drop other primitive columns to avoid conflicts
+         drop_columns = set(PRIMITIVE_TYPES) - {PrimitiveType, Classification}
+         remove_no_object_frames = remove_no_object_frames.drop(*[primitive.column_name() for primitive in drop_columns])
+         # Rename the sample columns "height", "width" and "meta" to avoid conflicts with object field names
+         remove_no_object_frames = remove_no_object_frames.rename(
+             {"height": "image.height", "width": "image.width", "meta": "image.meta"}
+         )
+         objects_df = remove_no_object_frames.explode(column_name).unnest(column_name)
+     else:
+         objects_df = remove_no_object_frames.select(pl.col(column_name).explode().struct.unnest())
+     return objects_df
+
+
+ def filter_table_for_class_names(
+     samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
+ ) -> Optional[pl.DataFrame]:
+     table_with_selected_class_names = samples_table.filter(
+         pl.col(PrimitiveType.column_name())
+         .list.eval(pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names))
+         .list.any()
+     )
+
+     return table_with_selected_class_names
+
+
+ def split_primitive_columns_by_task_name(
+     samples_table: pl.DataFrame,
+     coordinate_types: Optional[List[Type[Primitive]]] = None,
+ ) -> pl.DataFrame:
+     """
+     Convert Primitive columns such as "objects" (Bbox) into a column for each task name.
+     For example, if the "objects" column (containing Bbox objects) has tasks "task1" and "task2",
+
+     this:
+     ─┬────────────┬─
+      ┆ objects    ┆
+      ┆ ---        ┆
+      ┆ list[struc ┆
+      ┆ t[11]]     ┆
+     ═╪════════════╪═
+     becomes this:
+     ─┬────────────┬────────────┬─
+      ┆ objects.   ┆ objects.   ┆
+      ┆ task1      ┆ task2      ┆
+      ┆ ---        ┆ ---        ┆
+      ┆ list[struc ┆ list[struc ┆
+      ┆ t[11]]     ┆ t[13]]     ┆
+     ═╪════════════╪════════════╪═
+     """
+     coordinate_types = coordinate_types or PRIMITIVE_TYPES
+     for PrimitiveType in coordinate_types:
+         col_name = PrimitiveType.column_name()
+
+         if col_name not in samples_table.columns:
+             continue
+
+         if samples_table[col_name].dtype != pl.List(pl.Struct):
+             continue
+
+         task_names = samples_table[col_name].explode().struct.field(FieldName.TASK_NAME).unique().to_list()
+         samples_table = samples_table.with_columns(
+             [
+                 pl.col(col_name)
+                 .list.filter(pl.element().struct.field(FieldName.TASK_NAME).eq(task_name))
+                 .alias(f"{col_name}.{task_name}")
+                 for task_name in task_names
+             ]
+         )
+         samples_table = samples_table.drop(col_name)
+     return samples_table
+
+
+ def read_table_from_path(path: Path) -> pl.DataFrame:
+     path_annotations = path / FILENAME_ANNOTATIONS_PARQUET
+     if path_annotations.exists():
+         user_logger.info(f"Reading dataset annotations from Parquet file: {path_annotations}")
+         return pl.read_parquet(path_annotations)
+
+     path_annotations_jsonl = path / FILENAME_ANNOTATIONS_JSONL
+     if path_annotations_jsonl.exists():
+         user_logger.info(f"Reading dataset annotations from JSONL file: {path_annotations_jsonl}")
+         return pl.read_ndjson(path_annotations_jsonl)
+
+     raise FileNotFoundError(
+         f"Unable to read annotations. No Parquet file '{path_annotations.name}' or JSONL file '{path_annotations_jsonl.name}' in '{path}'."
+     )
+
+
+ def check_image_paths(table: pl.DataFrame) -> bool:
+     missing_files = []
+     for org_path in tqdm(table["file_name"].to_list(), desc="Check image paths"):
+         org_path = Path(org_path)
+         if not org_path.exists():
+             missing_files.append(org_path)
+
+     if len(missing_files) > 0:
+         user_logger.warning(f"Missing files: {len(missing_files)}. Show first 5:")
+         for missing_file in missing_files[:5]:
+             user_logger.warning(f"  - {missing_file}")
+         raise FileNotFoundError(f"Some files are missing in the dataset: {len(missing_files)} files not found.")
+
+     return True
+
+
+ def unnest_classification_tasks(table: pl.DataFrame, strict: bool = True) -> pl.DataFrame:
+     """
+     Unnest classification tasks in a table.
+     Classification tasks are all stored in the same column of the HafniaDataset table.
+     This function splits them into separate columns for each task name.
+
+     The type is converted from a list of structs (pl.List[pl.Struct]) to a struct (pl.Struct) column.
+
+     Converts the classification column from this:
+     ─┬─────────────────┬─
+      ┆ classifications ┆
+      ┆ ---             ┆
+      ┆ list[struct[6]] ┆
+     ═╪═════════════════╪═
+
+     for example, if the classification column has tasks "task1" and "task2", to this:
+     ─┬──────────────────┬──────────────────┬─
+      ┆ classifications. ┆ classifications. ┆
+      ┆ task1            ┆ task2            ┆
+      ┆ ---              ┆ ---              ┆
+      ┆ struct[6]        ┆ struct[6]        ┆
+     ═╪══════════════════╪══════════════════╪═
+     """
+     coordinate_types = [Classification]
+     table_out = table_transformations.split_primitive_columns_by_task_name(table, coordinate_types=coordinate_types)
+
+     classification_columns = [c for c in table_out.columns if c.startswith(Classification.column_name() + ".")]
+     for classification_column in classification_columns:
+         has_multiple_items_per_sample = all(table_out[classification_column].list.len() > 1)
+         if has_multiple_items_per_sample:
+             if strict:
+                 raise ValueError(
+                     f"Column {classification_column} has multiple items per sample, but expected only one item."
+                 )
+             else:
+                 user_logger.warning(
+                     f"Warning: Unnesting of column '{classification_column}' is skipped because it has multiple items per sample."
+                 )
+
+     table_out = table_out.with_columns([pl.col(c).list.first() for c in classification_columns])
+     return table_out
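
The split performed by split_primitive_columns_by_task_name can be illustrated with a small standalone polars snippet that uses the same expressions as the hunk above (the literal "task_name" field value and the toy rows are assumptions for illustration, not taken from the package):

    import polars as pl

    # Toy samples table: one "objects" column holding task-tagged structs.
    samples = pl.DataFrame(
        {
            "objects": [
                [{"task_name": "task1", "class_name": "car"}, {"task_name": "task2", "class_name": "red"}],
                [{"task_name": "task1", "class_name": "bus"}],
            ]
        }
    )

    # Mirrors the helper: one new column per task name, then drop the combined column.
    task_names = samples["objects"].explode().struct.field("task_name").unique().to_list()
    samples = samples.with_columns(
        [
            pl.col("objects")
            .list.filter(pl.element().struct.field("task_name").eq(task_name))
            .alias(f"objects.{task_name}")
            for task_name in task_names
        ]
    ).drop("objects")
    print(samples)  # columns: "objects.task1", "objects.task2"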
hafnia/experiment/hafnia_logger.py CHANGED
@@ -9,10 +9,10 @@ from typing import Dict, Optional, Union
  
  import pyarrow as pa
  import pyarrow.parquet as pq
- from datasets import DatasetDict
  from pydantic import BaseModel, field_validator
  
  from hafnia.data.factory import load_dataset
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
  from hafnia.log import sys_logger, user_logger
  from hafnia.utils import is_remote_job, now_as_str
  
@@ -92,7 +92,7 @@ class HafniaLogger:
          self.schema = Entity.create_schema()
          self.log_environment()
  
-     def load_dataset(self, dataset_name: str) -> DatasetDict:
+     def load_dataset(self, dataset_name: str) -> HafniaDataset:
          """
          Load a dataset from the specified path.
          """
hafnia/helper_testing.py ADDED
@@ -0,0 +1,63 @@
+ from pathlib import Path
+
+ from hafnia import utils
+ from hafnia.dataset.dataset_names import FILENAME_ANNOTATIONS_JSONL, DatasetVariant
+ from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+ MICRO_DATASETS = {
+     "tiny-dataset": utils.PATH_DATASETS / "tiny-dataset",
+     "coco-2017": utils.PATH_DATASETS / "coco-2017",
+ }
+
+
+ def get_path_workspace() -> Path:
+     return Path(__file__).parents[2]
+
+
+ def get_path_expected_images() -> Path:
+     return get_path_workspace() / "tests" / "data" / "expected_images"
+
+
+ def get_path_test_data() -> Path:
+     return get_path_workspace() / "tests" / "data"
+
+
+ def get_path_micro_hafnia_dataset_no_check() -> Path:
+     return get_path_test_data() / "micro_test_datasets"
+
+
+ def get_path_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Path:
+     import pytest
+
+     if dataset_name not in MICRO_DATASETS:
+         raise ValueError(f"Dataset name '{dataset_name}' is not recognized. Available options: {list(MICRO_DATASETS)}")
+     path_dataset = MICRO_DATASETS[dataset_name]
+
+     path_test_dataset = get_path_micro_hafnia_dataset_no_check() / dataset_name
+     path_test_dataset_annotations = path_test_dataset / FILENAME_ANNOTATIONS_JSONL
+     if path_test_dataset_annotations.exists() and not force_update:
+         return path_test_dataset
+
+     hafnia_dataset = HafniaDataset.read_from_path(path_dataset / DatasetVariant.SAMPLE.value)
+     hafnia_dataset = hafnia_dataset.sample(n_samples=3, seed=42)
+     hafnia_dataset.write(path_test_dataset)
+
+     if force_update:
+         pytest.fail(
+             "Sample image and metadata have been updated using 'force_update=True'. Set 'force_update=False' and rerun the test."
+         )
+     pytest.fail("Missing test sample image. Please rerun the test.")
+     return path_test_dataset
+
+
+ def get_sample_micro_hafnia_dataset(dataset_name: str, force_update=False) -> Sample:
+     micro_dataset = get_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     sample_dict = micro_dataset[0]
+     sample = Sample(**sample_dict)
+     return sample
+
+
+ def get_micro_hafnia_dataset(dataset_name: str, force_update: bool = False) -> HafniaDataset:
+     path_dataset = get_path_micro_hafnia_dataset(dataset_name=dataset_name, force_update=force_update)
+     hafnia_dataset = HafniaDataset.read_from_path(path_dataset)
+     return hafnia_dataset
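
A hedged sketch of how these helpers might be used from a test module (the test function itself is illustrative and not part of this diff; "tiny-dataset" is one of the MICRO_DATASETS keys defined above):

    from hafnia.helper_testing import get_micro_hafnia_dataset, get_sample_micro_hafnia_dataset


    def test_micro_dataset_loads():
        # Builds (or reuses) a cached 3-sample subset of the "tiny-dataset" sample variant.
        dataset = get_micro_hafnia_dataset("tiny-dataset")
        sample = get_sample_micro_hafnia_dataset("tiny-dataset")
        assert dataset is not None
        assert sample is not None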
hafnia/http.py CHANGED
@@ -31,7 +31,7 @@ def fetch(endpoint: str, headers: Dict, params: Optional[Dict] = None) -> Dict:
          http.clear()
  
  
- def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart: bool = False) -> Dict:
+ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes, str], multipart: bool = False) -> Dict:
      """Posts data to backend endpoint.
  
      Args:
@@ -64,9 +64,11 @@ def post(endpoint: str, headers: Dict, data: Union[Path, Dict, bytes], multipart
          with open(data, "rb") as f:
              body = f.read()
          response = http.request("POST", endpoint, body=body, headers=headers)
-     elif isinstance(data, dict):
+     elif isinstance(data, (str, dict)):
+         if isinstance(data, dict):
+             data = json.dumps(data)
          headers["Content-Type"] = "application/json"
-         response = http.request("POST", endpoint, body=json.dumps(data), headers=headers)
+         response = http.request("POST", endpoint, body=data, headers=headers)
      elif isinstance(data, bytes):
          response = http.request("POST", endpoint, body=data, headers=headers)
      else:
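
The widened signature means callers can now hand post() a pre-serialized JSON string, which is forwarded unchanged with an application/json content type; dicts are still serialized internally. A sketch of both call styles (the endpoint and API key are placeholders):

    import json

    from hafnia.http import post

    headers = {"Authorization": "<api-key>"}
    payload = {"name": "my-dataset"}

    post("https://api.example.com/v1/datasets", headers=headers, data=payload)              # dict, as before
    post("https://api.example.com/v1/datasets", headers=headers, data=json.dumps(payload))  # str, new in 0.2.0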
hafnia/platform/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from hafnia.platform.download import (
      download_resource,
      download_single_object,
-     get_resource_creds,
+     get_resource_credentials,
  )
  from hafnia.platform.experiment import (
      create_experiment,
@@ -17,5 +17,5 @@ __all__ = [
      "create_experiment",
      "download_resource",
      "download_single_object",
-     "get_resource_creds",
+     "get_resource_credentials",
  ]
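
For downstream code this is only an import rename; a migration one-liner (not part of the diff):

    # from hafnia.platform import get_resource_creds        # 0.1.26
    from hafnia.platform import get_resource_credentials     # 0.2.0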
hafnia/platform/builder.py CHANGED
@@ -1,7 +1,7 @@
  import json
  import os
+ import re
  import subprocess
- import tempfile
  import zipfile
  from hashlib import sha256
  from pathlib import Path
@@ -58,7 +58,7 @@ def buildx_available() -> bool:
          return False
  
  
- def build_dockerfile(dockerfile: str, docker_context: str, docker_tag: str, meta_file: str) -> None:
+ def build_dockerfile(dockerfile: str, docker_context: str, docker_tag: str) -> None:
      """
      Build a Docker image using the provided Dockerfile.
  
@@ -73,12 +73,12 @@ def build_dockerfile(dockerfile: str, docker_context: str, docker_tag: str, meta
  
      cmd = ["docker", "build", "--platform", "linux/amd64", "-t", docker_tag, "-f", dockerfile]
  
-     remote_cache = os.getenv("REMOTE_CACHE_REPO")
+     remote_cache = os.getenv("EXPERIMENT_CACHE_ECR")
      cloud_mode = os.getenv("HAFNIA_CLOUD", "false").lower() in ["true", "1", "yes"]
  
      if buildx_available():
          cmd.insert(1, "buildx")
-         cmd += ["--build-arg", "BUILDKIT_INLINE_CACHE=1", "--metadata-file", meta_file]
+         cmd += ["--build-arg", "BUILDKIT_INLINE_CACHE=1"]
          if cloud_mode:
              cmd += ["--push"]
          if remote_cache:
@@ -91,11 +91,27 @@
      cmd.append(docker_context)
      sys_logger.debug("Build cmd: `{}`".format(" ".join(cmd)))
      sys_logger.info(f"Building and pushing Docker image with BuildKit (buildx); cache repo: {remote_cache or 'none'}")
+     result = None
+     output = ""
+     errors = []
      try:
-         subprocess.run(cmd, check=True)
+         result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+         output = (result.stdout or "") + (result.stderr or "")
      except subprocess.CalledProcessError as e:
-         sys_logger.error(f"Docker build failed: {e}")
-         raise RuntimeError(f"Docker build failed: {e}")
+         output = (e.stdout or "") + (e.stderr or "")
+         error_pattern = r"ERROR: (.+?)(?:\n|$)"
+         errors = re.findall(error_pattern, output)
+         if not errors:
+             raise RuntimeError(f"Docker build failed: {output}")
+         if re.search(r"image tag '([^']+)' already exists", errors[-1]):
+             sys_logger.warning("Image {} already exists in the registry.".format(docker_tag.rsplit("/")[-1]))
+             return
+         raise RuntimeError(f"Docker build failed: {output}")
+     finally:
+         stage_pattern = r"^.*\[\d+/\d+\][^\n]*"
+         stages = re.findall(stage_pattern, output, re.MULTILINE)
+         user_logger.info("\n".join(stages))
+         sys_logger.debug(output)
  
  
  def check_registry(docker_image: str) -> Optional[str]:
@@ -127,18 +143,8 @@ def build_image(metadata: Dict, registry_repo: str, state_file: str = "state.jso
      docker_image = f"{registry_repo}:{metadata['digest']}"
      image_exists = check_registry(docker_image) is not None
      if image_exists:
-         sys_logger.info(f"Tag already in ECR – skipping build of {docker_image}.")
+         sys_logger.info("Image {} already exists in the registry.".format(docker_image.rsplit("/")[-1]))
      else:
-         with tempfile.NamedTemporaryFile() as meta_tmp:
-             meta_file = meta_tmp.name
-             build_dockerfile(
-                 metadata["dockerfile"], Path(metadata["dockerfile"]).parent.as_posix(), docker_image, meta_file
-             )
-             with open(meta_file) as m:
-                 try:
-                     build_meta = json.load(m)
-                     metadata["local_digest"] = build_meta["containerimage.digest"]
-                 except Exception:
-                     metadata["local_digest"] = ""
+         build_dockerfile(metadata["dockerfile"], Path(metadata["dockerfile"]).parent.as_posix(), docker_image)
      metadata.update({"image_tag": docker_image, "image_exists": image_exists})
      Path(state_file).write_text(json.dumps(metadata, indent=2))
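
The new error handling parses the captured BuildKit output instead of only reporting the CalledProcessError. A self-contained sketch of what the two regexes extract, using the same patterns as build_dockerfile above (the sample output text is fabricated for illustration):

    import re

    output = (
        "#5 [2/4] RUN pip install -r requirements.txt\n"
        "#8 [4/4] COPY . /app\n"
        "ERROR: failed to solve: image tag 'repo:abc123' already exists\n"
    )

    errors = re.findall(r"ERROR: (.+?)(?:\n|$)", output)
    stages = re.findall(r"^.*\[\d+/\d+\][^\n]*", output, re.MULTILINE)

    # build_dockerfile downgrades this specific failure to a warning and returns early.
    already_exists = bool(errors) and re.search(r"image tag '([^']+)' already exists", errors[-1]) is not None
    print(already_exists)  # True
    print(stages)          # the two "[n/m]" build-stage lines shown to the user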
hafnia/platform/datasets.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import shutil
+ import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import rich
+ from tqdm import tqdm
+
+ from cli.config import Config
+ from hafnia import utils
+ from hafnia.dataset import dataset_names
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
+ from hafnia.http import fetch
+ from hafnia.log import user_logger
+ from hafnia.platform import get_dataset_id
+ from hafnia.platform.download import get_resource_credentials
+ from hafnia.utils import timed
+
+
+ @timed("Fetching dataset list.")
+ def dataset_list(cfg: Optional[Config] = None) -> List[Dict[str, str]]:
+     """List available datasets on the Hafnia platform."""
+     cfg = cfg or Config()
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     header = {"Authorization": cfg.api_key}
+     datasets: List[Dict[str, str]] = fetch(endpoint_dataset, headers=header)  # type: ignore
+     if not datasets:
+         raise ValueError("No datasets found on the Hafnia platform.")
+
+     return datasets
+
+
+ def download_or_get_dataset_path(
+     dataset_name: str,
+     cfg: Optional[Config] = None,
+     path_datasets_folder: Optional[str] = None,
+     force_redownload: bool = False,
+ ) -> Path:
+     """Download or get the path of the dataset."""
+     if utils.is_remote_job():
+         return Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
+
+     path_datasets_folder = path_datasets_folder or str(utils.PATH_DATASETS)
+     path_dataset = Path(path_datasets_folder).absolute() / dataset_name
+
+     is_dataset_valid = HafniaDataset.check_dataset_path(path_dataset, raise_error=False)
+     if is_dataset_valid and not force_redownload:
+         user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
+         return path_dataset
+
+     cfg = cfg or Config()
+     api_key = cfg.api_key
+
+     shutil.rmtree(path_dataset, ignore_errors=True)
+
+     endpoint_dataset = cfg.get_platform_endpoint("datasets")
+     dataset_id = get_dataset_id(dataset_name=dataset_name, endpoint=endpoint_dataset, api_key=api_key)
+     access_dataset_endpoint = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
+
+     download_dataset_from_access_endpoint(
+         endpoint=access_dataset_endpoint,
+         api_key=api_key,
+         path_dataset=path_dataset,
+     )
+     return path_dataset
+
+
+ def download_dataset_from_access_endpoint(endpoint: str, api_key: str, path_dataset: Path) -> None:
+     resource_credentials = get_resource_credentials(endpoint, api_key)
+
+     local_dataset_paths = [str(path_dataset / filename) for filename in dataset_names.DATASET_FILENAMES]
+     s3_uri = resource_credentials.s3_uri()
+     s3_dataset_files = [f"{s3_uri}/{filename}" for filename in dataset_names.DATASET_FILENAMES]
+
+     envs = resource_credentials.aws_credentials()
+     fast_copy_files_s3(
+         src_paths=s3_dataset_files,
+         dst_paths=local_dataset_paths,
+         append_envs=envs,
+         description="Downloading annotations",
+     )
+
+     dataset = HafniaDataset.read_from_path(path_dataset, check_for_images=False)
+     fast_copy_files_s3(
+         src_paths=dataset.samples[dataset_names.ColumnName.REMOTE_PATH].to_list(),
+         dst_paths=dataset.samples[dataset_names.ColumnName.FILE_NAME].to_list(),
+         append_envs=envs,
+         description="Downloading images",
+     )
+
+
+ def fast_copy_files_s3(
+     src_paths: List[str],
+     dst_paths: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Copying files",
+ ) -> List[str]:
+     if len(src_paths) != len(dst_paths):
+         raise ValueError("Source and destination paths must have the same length.")
+
+     cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
+     lines = execute_s5cmd_commands(cmds, append_envs=append_envs, description=description)
+     return lines
+
+
+ def execute_s5cmd_commands(
+     commands: List[str],
+     append_envs: Optional[Dict[str, str]] = None,
+     description: str = "Executing s5cmd commands",
+ ) -> List[str]:
+     append_envs = append_envs or {}
+     with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_file:
+         tmp_file_path = Path(tmp_file.name)
+         tmp_file_path.write_text("\n".join(commands))
+         run_cmds = [
+             "s5cmd",
+             "run",
+             str(tmp_file_path),
+         ]
+         envs = os.environ.copy()
+         envs.update(append_envs)
+
+         process = subprocess.Popen(
+             run_cmds,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             universal_newlines=True,
+             env=envs,
+         )
+
+         error_lines = []
+         lines = []
+         for line in tqdm(process.stdout, total=len(commands), desc=description):
+             if "ERROR" in line or "error" in line:
+                 error_lines.append(line.strip())
+             lines.append(line.strip())
+
+         if len(error_lines) > 0:
+             show_n_lines = min(5, len(error_lines))
+             str_error_lines = "\n".join(error_lines[:show_n_lines])
+             user_logger.error(
+                 f"Detected {len(error_lines)} errors occurred while executing a total of {len(commands)} "
+                 f" commands with s5cmd. The first {show_n_lines} is printed below:\n{str_error_lines}"
+             )
+             raise RuntimeError("Errors occurred during s5cmd execution.")
+         return lines
+
+
+ TABLE_FIELDS = {
+     "ID": "id",
+     "Hidden\nSamples": "hidden.samples",
+     "Hidden\nSize": "hidden.size",
+     "Sample\nSamples": "sample.samples",
+     "Sample\nSize": "sample.size",
+     "Name": "name",
+     "Title": "title",
+ }
+
+
+ def create_rich_table_from_dataset(datasets: List[Dict[str, str]]) -> rich.table.Table:
+     datasets = extend_dataset_details(datasets)
+     datasets = sorted(datasets, key=lambda x: x["name"].lower())
+
+     table = rich.table.Table(title="Available Datasets")
+     for i_dataset, dataset in enumerate(datasets):
+         if i_dataset == 0:
+             for column_name, _ in TABLE_FIELDS.items():
+                 table.add_column(column_name, justify="left", style="cyan", no_wrap=True)
+         row = [str(dataset.get(field, "")) for field in TABLE_FIELDS.values()]
+         table.add_row(*row)
+
+     return table
+
+
+ def extend_dataset_details(datasets: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Extends dataset details with number of samples and size"""
+     for dataset in datasets:
+         for variant in dataset["dataset_variants"]:
+             variant_type = variant["variant_type"]
+             dataset[f"{variant_type}.samples"] = variant["number_of_data_items"]
+             dataset[f"{variant_type}.size"] = utils.size_human_readable(variant["size_bytes"])
+     return datasets
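
The download path batches all S3 copies into a single s5cmd invocation: fast_copy_files_s3 composes one "cp <src> <dst>" line per file and execute_s5cmd_commands writes them to a temporary run file for `s5cmd run`. A minimal sketch of that composition (bucket and paths are placeholders; the real call also injects temporary AWS credentials through append_envs):

    src_paths = ["s3://example-bucket/annotations.parquet", "s3://example-bucket/data/img_0001.jpg"]
    dst_paths = ["datasets/my-dataset/annotations.parquet", "datasets/my-dataset/data/img_0001.jpg"]

    # One copy command per file, exactly as fast_copy_files_s3 builds them ...
    cmds = [f"cp {src} {dst}" for src, dst in zip(src_paths, dst_paths)]
    print("\n".join(cmds))
    # ... which execute_s5cmd_commands writes to a temp file and runs as: s5cmd run <file>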