hafnia 0.1.27__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
Files changed (37)
  1. cli/__main__.py +2 -2
  2. cli/dataset_cmds.py +60 -0
  3. cli/runc_cmds.py +1 -1
  4. hafnia/data/__init__.py +2 -2
  5. hafnia/data/factory.py +9 -56
  6. hafnia/dataset/dataset_helpers.py +91 -0
  7. hafnia/dataset/dataset_names.py +71 -0
  8. hafnia/dataset/dataset_transformation.py +187 -0
  9. hafnia/dataset/dataset_upload_helper.py +468 -0
  10. hafnia/dataset/hafnia_dataset.py +453 -0
  11. hafnia/dataset/primitives/__init__.py +16 -0
  12. hafnia/dataset/primitives/bbox.py +137 -0
  13. hafnia/dataset/primitives/bitmask.py +182 -0
  14. hafnia/dataset/primitives/classification.py +56 -0
  15. hafnia/dataset/primitives/point.py +25 -0
  16. hafnia/dataset/primitives/polygon.py +100 -0
  17. hafnia/dataset/primitives/primitive.py +44 -0
  18. hafnia/dataset/primitives/segmentation.py +51 -0
  19. hafnia/dataset/primitives/utils.py +51 -0
  20. hafnia/dataset/table_transformations.py +183 -0
  21. hafnia/experiment/hafnia_logger.py +2 -2
  22. hafnia/helper_testing.py +63 -0
  23. hafnia/http.py +5 -3
  24. hafnia/platform/__init__.py +2 -2
  25. hafnia/platform/datasets.py +184 -0
  26. hafnia/platform/download.py +85 -23
  27. hafnia/torch_helpers.py +180 -95
  28. hafnia/utils.py +1 -1
  29. hafnia/visualizations/colors.py +267 -0
  30. hafnia/visualizations/image_visualizations.py +202 -0
  31. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
  32. hafnia-0.2.0.dist-info/RECORD +46 -0
  33. cli/data_cmds.py +0 -53
  34. hafnia-0.1.27.dist-info/RECORD +0 -27
  35. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
  36. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
  37. {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py CHANGED
@@ -1,7 +1,7 @@
  #!/usr/bin/env python
  import click

- from cli import consts, data_cmds, experiment_cmds, profile_cmds, recipe_cmds, runc_cmds
+ from cli import consts, dataset_cmds, experiment_cmds, profile_cmds, recipe_cmds, runc_cmds
  from cli.config import Config, ConfigSchema


@@ -46,7 +46,7 @@ def clear(cfg: Config) -> None:


  main.add_command(profile_cmds.profile)
- main.add_command(data_cmds.data)
+ main.add_command(dataset_cmds.dataset)
  main.add_command(runc_cmds.runc)
  main.add_command(experiment_cmds.experiment)
  main.add_command(recipe_cmds.recipe)
cli/dataset_cmds.py ADDED
@@ -0,0 +1,60 @@
+ from pathlib import Path
+ from typing import Optional
+
+ import click
+ from rich import print as rprint
+
+ import cli.consts as consts
+ from cli.config import Config
+ from hafnia import utils
+ from hafnia.platform.datasets import create_rich_table_from_dataset
+
+
+ @click.group()
+ def dataset():
+     """Manage dataset interaction"""
+     pass
+
+
+ @dataset.command("ls")
+ @click.pass_obj
+ def dataset_list(cfg: Config) -> None:
+     """List available datasets on Hafnia platform"""
+
+     from hafnia.platform.datasets import dataset_list
+
+     try:
+         datasets = dataset_list(cfg=cfg)
+     except Exception:
+         raise click.ClickException(consts.ERROR_GET_RESOURCE)
+
+     table = create_rich_table_from_dataset(datasets)
+     rprint(table)
+
+
+ @dataset.command("download")
+ @click.argument("dataset_name")
+ @click.option(
+     "--destination",
+     "-d",
+     default=None,
+     required=False,
+     help=f"Destination folder to save the dataset. Defaults to '{utils.PATH_DATASETS}/<dataset_name>'",
+ )
+ @click.option("--force", "-f", is_flag=True, default=False, help="Flag to enable force redownload")
+ @click.pass_obj
+ def data_download(cfg: Config, dataset_name: str, destination: Optional[click.Path], force: bool) -> Path:
+     """Download dataset from Hafnia platform"""
+
+     from hafnia.platform import datasets
+
+     try:
+         path_dataset = datasets.download_or_get_dataset_path(
+             dataset_name=dataset_name,
+             cfg=cfg,
+             path_datasets_folder=destination,
+             force_redownload=force,
+         )
+     except Exception:
+         raise click.ClickException(consts.ERROR_GET_RESOURCE)
+     return path_dataset
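Note: the old `data` group is replaced by this `dataset` group (wired up in `cli/__main__.py` above), so the CLI surface becomes `dataset ls` and `dataset download <dataset_name> [-d DEST] [-f]` under the package's console script. A minimal sketch of exercising the group programmatically with Click's test runner; `Config()` with no arguments mirrors its use in the old `factory.py`, the dataset name is a placeholder, and a configured profile/API key is assumed:

    from click.testing import CliRunner

    from cli import dataset_cmds
    from cli.config import Config

    runner = CliRunner()

    # List datasets available on the Hafnia platform (requires valid credentials)
    result = runner.invoke(dataset_cmds.dataset, ["ls"], obj=Config())
    print(result.output)

    # Download a dataset by name; --force re-downloads even if it already exists locally
    result = runner.invoke(dataset_cmds.dataset, ["download", "my-dataset", "--force"], obj=Config())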
cli/runc_cmds.py CHANGED
@@ -38,7 +38,7 @@ def runc():
  @click.pass_obj
  def launch_local(cfg: Config, exec_cmd: str, dataset: str, image_name: str) -> None:
      """Launch a job within the image."""
-     from hafnia.data.factory import download_or_get_dataset_path
+     from hafnia.platform.datasets import download_or_get_dataset_path

      is_local_dataset = "/" in dataset
      if is_local_dataset:
hafnia/data/__init__.py CHANGED
@@ -1,3 +1,3 @@
- from hafnia.data.factory import load_dataset
+ from hafnia.data.factory import get_dataset_path, load_dataset

- __all__ = ["load_dataset"]
+ __all__ = ["load_dataset", "get_dataset_path"]
hafnia/data/factory.py CHANGED
@@ -1,67 +1,20 @@
- import os
- import shutil
  from pathlib import Path
- from typing import Optional, Union

- from datasets import Dataset, DatasetDict, load_from_disk
+ from hafnia.dataset.hafnia_dataset import HafniaDataset
+ from hafnia.platform.datasets import download_or_get_dataset_path

- from cli.config import Config
- from hafnia import utils
- from hafnia.log import user_logger
- from hafnia.platform import download_resource, get_dataset_id

-
- def load_local(dataset_path: Path) -> Union[Dataset, DatasetDict]:
-     """Load a Hugging Face dataset from a local directory path."""
-     if not dataset_path.exists():
-         raise ValueError(f"Can not load dataset, directory does not exist -- {dataset_path}")
-     user_logger.info(f"Loading data from {dataset_path.as_posix()}")
-     return load_from_disk(dataset_path.as_posix())
-
-
- def download_or_get_dataset_path(
-     dataset_name: str,
-     cfg: Optional[Config] = None,
-     output_dir: Optional[str] = None,
-     force_redownload: bool = False,
- ) -> Path:
-     """Download or get the path of the dataset."""
-
-     cfg = cfg or Config()
-     endpoint_dataset = cfg.get_platform_endpoint("datasets")
-     api_key = cfg.api_key
-
-     output_dir = output_dir or str(utils.PATH_DATASET)
-     dataset_path_base = Path(output_dir).absolute() / dataset_name
-     dataset_path_base.mkdir(exist_ok=True, parents=True)
-     dataset_path_sample = dataset_path_base / "sample"
-
-     if dataset_path_sample.exists() and not force_redownload:
-         user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
-         return dataset_path_sample
-
-     dataset_id = get_dataset_id(dataset_name, endpoint_dataset, api_key)
-     dataset_access_info_url = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
-
-     if force_redownload and dataset_path_sample.exists():
-         # Remove old files to avoid old files conflicting with new files
-         shutil.rmtree(dataset_path_sample, ignore_errors=True)
-     status = download_resource(dataset_access_info_url, str(dataset_path_base), api_key)
-     if status:
-         return dataset_path_sample
-     raise RuntimeError("Failed to download dataset")
-
-
- def load_dataset(dataset_name: str, force_redownload: bool = False) -> Union[Dataset, DatasetDict]:
+ def load_dataset(dataset_name: str, force_redownload: bool = False) -> HafniaDataset:
      """Load a dataset either from a local path or from the Hafnia platform."""

-     if utils.is_remote_job():
-         path_dataset = Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
-         return load_local(path_dataset)
+     path_dataset = get_dataset_path(dataset_name, force_redownload=force_redownload)
+     dataset = HafniaDataset.read_from_path(path_dataset)
+     return dataset
+

+ def get_dataset_path(dataset_name: str, force_redownload: bool = False) -> Path:
      path_dataset = download_or_get_dataset_path(
          dataset_name=dataset_name,
          force_redownload=force_redownload,
      )
-     dataset = load_local(path_dataset)
-     return dataset
+     return path_dataset
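With this change, `load_dataset` returns a `HafniaDataset` instead of a Hugging Face `Dataset`/`DatasetDict`, and the new `get_dataset_path` only resolves (downloading if needed) the local dataset folder. A minimal sketch of the new call pattern; the dataset name is a placeholder:

    from hafnia.data import get_dataset_path, load_dataset  # re-exported by hafnia/data/__init__.py

    # "my-dataset" is a placeholder for any dataset available on the Hafnia platform
    dataset = load_dataset("my-dataset")                 # returns a HafniaDataset
    path = get_dataset_path("my-dataset")                # returns the local Path, downloading only if missing
    fresh = load_dataset("my-dataset", force_redownload=True)  # forces a re-download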
hafnia/dataset/dataset_helpers.py ADDED
@@ -0,0 +1,91 @@
+ import io
+ import math
+ import random
+ from pathlib import Path
+ from typing import Dict, List
+
+ import numpy as np
+ import xxhash
+ from PIL import Image
+
+
+ def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
+     samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
+
+     split_name_column = []
+     for split_name, n_split_samples in samples_per_split.items():
+         split_name_column.extend([split_name] * n_split_samples)
+     random.Random(seed).shuffle(split_name_column)  # Shuffle the split names
+
+     return split_name_column
+
+
+ def hash_file_xxhash(path: Path, chunk_size: int = 262144) -> str:
+     hasher = xxhash.xxh3_64()
+
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(chunk_size), b""):  # 8192, 16384, 32768, 65536
+             hasher.update(chunk)
+     return hasher.hexdigest()
+
+
+ def hash_from_bytes(data: bytes) -> str:
+     hasher = xxhash.xxh3_64()
+     hasher.update(data)
+     return hasher.hexdigest()
+
+
+ def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
+     pil_image = Image.fromarray(image)
+     buffer = io.BytesIO()
+     pil_image.save(buffer, format="PNG")
+     hash_value = hash_from_bytes(buffer.getvalue())
+     path_image = Path(path_folder) / f"{hash_value}.png"
+     pil_image.save(path_image)
+     return path_image
+
+
+ def filename_as_hash_from_path(path_image: Path) -> str:
+     hash = hash_file_xxhash(path_image)
+     return f"{hash}{path_image.suffix}"
+
+
+ def split_sizes_from_ratios(n_items: int, split_ratios: Dict[str, float]) -> Dict[str, int]:
+     summed_ratios = sum(split_ratios.values())
+     abs_tols = 0.0011  # Allow some tolerance for floating point errors {"test": 0.333, "val": 0.333, "train": 0.333}
+     if not math.isclose(summed_ratios, 1.0, abs_tol=abs_tols):  # Allow tolerance to allow e.g. (0.333, 0.333, 0.333)
+         raise ValueError(f"Split ratios must sum to 1.0. The summed values of {split_ratios} is {summed_ratios}")
+
+     # recaculate split sizes
+     split_ratios = {split_name: split_ratio / summed_ratios for split_name, split_ratio in split_ratios.items()}
+     split_sizes = {split_name: int(n_items * split_ratio) for split_name, split_ratio in split_ratios.items()}
+
+     remaining_items = n_items - sum(split_sizes.values())
+     if remaining_items > 0:  # Distribute remaining items evenly across splits
+         for _ in range(remaining_items):
+             # Select name by the largest error from the expected distribution
+             total_size = sum(split_sizes.values())
+             distribution_error = {
+                 split_name: abs(split_ratios[split_name] - (size / total_size))
+                 for split_name, size in split_sizes.items()
+             }
+
+             split_with_largest_error = sorted(distribution_error.items(), key=lambda x: x[1], reverse=True)[0][0]
+             split_sizes[split_with_largest_error] += 1
+
+     if sum(split_sizes.values()) != n_items:
+         raise ValueError("Something is wrong. The split sizes do not match the number of items.")
+
+     return split_sizes
+
+
+ def select_evenly_across_list(lst: list, num_samples: int):
+     if num_samples >= len(lst):
+         return lst  # No need to sample
+     step = (len(lst) - 1) / (num_samples - 1)
+     indices = [int(round(step * i)) for i in range(num_samples)]  # noqa: RUF046
+     return [lst[index] for index in indices]
+
+
+ def prefix_dict(d: dict, prefix: str) -> dict:
+     return {f"{prefix}.{k}": v for k, v in d.items()}
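For reference, a small worked example of the new split helpers above; the ratios and item count are chosen for illustration only, and any leftover items are assigned by the largest-error rule in `split_sizes_from_ratios`:

    from hafnia.dataset.dataset_helpers import create_split_name_list_from_ratios, split_sizes_from_ratios

    ratios = {"train": 0.8, "validation": 0.1, "test": 0.1}

    sizes = split_sizes_from_ratios(n_items=10, split_ratios=ratios)
    # -> {"train": 8, "validation": 1, "test": 1}; the sizes always sum to n_items

    names = create_split_name_list_from_ratios(split_ratios=ratios, n_items=10, seed=42)
    # -> a deterministically shuffled list of 10 split names, e.g. ["train", "test", "train", ...]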
hafnia/dataset/dataset_names.py ADDED
@@ -0,0 +1,71 @@
+ from enum import Enum
+ from typing import List
+
+ FILENAME_DATASET_INFO = "dataset_info.json"
+ FILENAME_ANNOTATIONS_JSONL = "annotations.jsonl"
+ FILENAME_ANNOTATIONS_PARQUET = "annotations.parquet"
+
+ DATASET_FILENAMES = [
+     FILENAME_DATASET_INFO,
+     FILENAME_ANNOTATIONS_JSONL,
+     FILENAME_ANNOTATIONS_PARQUET,
+ ]
+
+
+ class DeploymentStage(Enum):
+     STAGING = "staging"
+     PRODUCTION = "production"
+
+
+ class FieldName:
+     CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
+     CLASS_IDX: str = (
+         "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class
+     )
+     OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
+     CONFIDENCE: str = "confidence"  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
+
+     META: str = "meta"  # Contains metadata about each primitive, e.g. attributes color, occluded, iscrowd, etc.
+     TASK_NAME: str = "task_name"  # Name of the task this primitive is associated with, e.g. "bboxes" for Bbox
+
+     @staticmethod
+     def fields() -> List[str]:
+         """
+         Returns a list of expected field names for primitives.
+         """
+         return [
+             FieldName.CLASS_NAME,
+             FieldName.CLASS_IDX,
+             FieldName.OBJECT_ID,
+             FieldName.CONFIDENCE,
+             FieldName.META,
+             FieldName.TASK_NAME,
+         ]
+
+
+ class ColumnName:
+     SAMPLE_INDEX: str = "sample_index"
+     FILE_NAME: str = "file_name"
+     HEIGHT: str = "height"
+     WIDTH: str = "width"
+     SPLIT: str = "split"
+     IS_SAMPLE: str = "is_sample"
+     REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+     META: str = "meta"
+
+
+ class SplitName:
+     TRAIN = "train"
+     VAL = "validation"
+     TEST = "test"
+     UNDEFINED = "UNDEFINED"
+
+     @staticmethod
+     def valid_splits() -> List[str]:
+         return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
+
+
+ class DatasetVariant(Enum):
+     DUMP = "dump"
+     SAMPLE = "sample"
+     HIDDEN = "hidden"
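These classes centralize the column, field, and split name strings used throughout the new dataset code (e.g. `ColumnName.SPLIT` in dataset_transformation.py below). A small illustrative snippet of referencing them instead of raw strings; the DataFrame here is made up for the example:

    import polars as pl

    from hafnia.dataset.dataset_names import ColumnName, SplitName

    # Toy samples table using the column names the library expects
    samples = pl.DataFrame(
        {
            ColumnName.FILE_NAME: ["a.png", "b.png", "c.png"],
            ColumnName.SPLIT: [SplitName.TRAIN, SplitName.TRAIN, SplitName.TEST],
        }
    )
    train_samples = samples.filter(pl.col(ColumnName.SPLIT) == SplitName.TRAIN)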
hafnia/dataset/dataset_transformation.py ADDED
@@ -0,0 +1,187 @@
+ import hashlib
+ import shutil
+ from pathlib import Path
+ from random import Random
+ from typing import TYPE_CHECKING, Callable, Dict
+
+ import cv2
+ import numpy as np
+ import polars as pl
+ from PIL import Image
+ from tqdm import tqdm
+
+ from hafnia.dataset import dataset_helpers
+ from hafnia.dataset.dataset_names import ColumnName
+ from hafnia.log import user_logger
+
+ if TYPE_CHECKING:
+     from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+ ### Image transformations ###
+ class AnonymizeByPixelation:
+     def __init__(self, resize_factor: float = 0.10):
+         self.resize_factor = resize_factor
+
+     def __call__(self, frame: np.ndarray) -> np.ndarray:
+         org_size = frame.shape[:2]
+         frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
+         frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
+         return frame
+
+
+ def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
+     """
+     Divides the dataset into splits based on the provided ratios.
+
+     Example: Defining split ratios and applying the transformation
+
+     >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+     >>> split_ratios = {SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1}
+     >>> dataset_with_splits = splits_by_ratios(dataset, split_ratios, seed=42)
+     Or use the function as a
+     >>> dataset_with_splits = dataset.splits_by_ratios(split_ratios, seed=42)
+     """
+     n_items = len(dataset)
+     split_name_column = dataset_helpers.create_split_name_list_from_ratios(
+         split_ratios=split_ratios, n_items=n_items, seed=seed
+     )
+     table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
+     return dataset.update_table(table)
+
+
+ def divide_split_into_multiple_splits(
+     dataset: "HafniaDataset",
+     divide_split_name: str,
+     split_ratios: Dict[str, float],
+ ) -> "HafniaDataset":
+     """
+     Divides a dataset split ('divide_split_name') into multiple splits based on the provided split
+     ratios ('split_ratios'). This is especially useful for some open datasets where they have only provide
+     two splits or only provide annotations for two splits. This function allows you to create additional
+     splits based on the provided ratios.
+
+     Example: Defining split ratios and applying the transformation
+     >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+     >>> divide_split_name = SplitName.TEST
+     >>> split_ratios = {SplitName.TEST: 0.8, SplitName.VAL: 0.2}
+     >>> dataset_with_splits = divide_split_into_multiple_splits(dataset, divide_split_name, split_ratios)
+     """
+     dataset_split_to_be_divided = dataset.create_split_dataset(split_name=divide_split_name)
+     if len(dataset_split_to_be_divided) == 0:
+         split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
+         raise ValueError(
+             f"No samples in the '{divide_split_name}' split to divide into multiple splits. {split_counts=}"
+         )
+     assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{divide_split_name}' split!"
+     dataset_split_to_be_divided = dataset_split_to_be_divided.split_by_ratios(split_ratios=split_ratios, seed=42)
+
+     remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([divide_split_name]).not_())
+     new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
+     dataset_new = dataset.update_table(new_table)
+     return dataset_new
+
+
+ def shuffle_dataset(dataset: "HafniaDataset", seed: int = 42) -> "HafniaDataset":
+     table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
+     return dataset.update_table(table)
+
+
+ def sample(dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42) -> "HafniaDataset":
+     table = dataset.samples.sample(n=n_samples, with_replacement=False, seed=seed, shuffle=shuffle)
+     return dataset.update_table(table)
+
+
+ def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+     is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
+     is_sample_column = [False for _ in range(len(dataset))]
+     for idx in is_sample_indices:
+         is_sample_column[idx] = True
+
+     table = dataset.samples.with_columns(pl.Series(is_sample_column).alias("is_sample"))
+     return dataset.update_table(table)
+
+
+ def transform_images(
+     dataset: "HafniaDataset",
+     transform: Callable[[np.ndarray], np.ndarray],
+     path_output: Path,
+ ) -> "HafniaDataset":
+     new_paths = []
+     path_image_folder = path_output / "data"
+     path_image_folder.mkdir(parents=True, exist_ok=True)
+
+     for org_path in tqdm(dataset.samples["file_name"].to_list(), desc="Transform images"):
+         org_path = Path(org_path)
+         if not org_path.exists():
+             raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+
+         image = np.array(Image.open(org_path))
+         image_transformed = transform(image)
+         new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)
+
+         if not new_path.exists():
+             raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
+         new_paths.append(str(new_path))
+
+     table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
+     return dataset.update_table(table)
+
+
+ def rename_to_unique_image_names(dataset: "HafniaDataset", path_output: Path) -> "HafniaDataset":
+     user_logger.info(f"Copy images to have unique filenames. New path is '{path_output}'")
+     shutil.rmtree(path_output, ignore_errors=True)  # Remove the output folder if it exists
+     new_paths = []
+     for org_path in tqdm(dataset.samples["file_name"].to_list(), desc="- Rename/copy images"):
+         org_path = Path(org_path)
+         if not org_path.exists():
+             raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+
+         hash_name = hashlib.md5(str(org_path).encode()).hexdigest()[
+             :6
+         ]  # Generate a unique name based on the original file name
+         new_path = path_output / "data" / f"{hash_name}_{org_path.name}"
+         if not new_path.parent.exists():
+             new_path.parent.mkdir(parents=True, exist_ok=True)
+
+         shutil.copyfile(org_path, new_path)  # Copy the original file to the new path
+         new_paths.append(str(new_path))
+
+     table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
+     return dataset.update_table(table)
+
+
+ ### Hafnia Dataset Transformations ###
+ class SplitsByRatios:
+     def __init__(self, split_ratios: dict, seed: int = 42):
+         self.split_ratios = split_ratios
+         self.seed = seed
+
+     def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+         return splits_by_ratios(dataset, self.split_ratios, self.seed)
+
+
+ class ShuffleDataset:
+     def __init__(self, seed: int = 42):
+         self.seed = seed
+
+     def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+         return shuffle_dataset(dataset, self.seed)
+
+
+ class SampleSetBySize:
+     def __init__(self, n_samples: int, seed: int = 42):
+         self.n_samples = n_samples
+         self.seed = seed
+
+     def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+         return define_sample_set_by_size(dataset, self.n_samples, self.seed)
+
+
+ class TransformImages:
+     def __init__(self, transform: Callable[[np.ndarray], np.ndarray], path_output: Path):
+         self.transform = transform
+         self.path_output = path_output
+
+     def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+         return transform_images(dataset, self.transform, self.path_output)
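Each class in this module simply stores its parameters and applies the matching function when called, which makes the transformations composable. A sketch of chaining them on a dataset; the path and ratios are placeholders, and the composition loop is illustrative rather than an API shown in this diff:

    from pathlib import Path

    from hafnia.dataset.dataset_names import SplitName
    from hafnia.dataset.dataset_transformation import (
        AnonymizeByPixelation,
        ShuffleDataset,
        SplitsByRatios,
        TransformImages,
    )
    from hafnia.dataset.hafnia_dataset import HafniaDataset

    dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))  # placeholder path

    transforms = [
        ShuffleDataset(seed=42),
        SplitsByRatios({SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1}, seed=42),
        TransformImages(AnonymizeByPixelation(resize_factor=0.1), path_output=Path("anonymized_dataset")),
    ]
    for transform in transforms:
        dataset = transform(dataset)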