hafnia-0.1.27-py3-none-any.whl → hafnia-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +2 -2
- cli/dataset_cmds.py +60 -0
- cli/runc_cmds.py +1 -1
- hafnia/data/__init__.py +2 -2
- hafnia/data/factory.py +9 -56
- hafnia/dataset/dataset_helpers.py +91 -0
- hafnia/dataset/dataset_names.py +71 -0
- hafnia/dataset/dataset_transformation.py +187 -0
- hafnia/dataset/dataset_upload_helper.py +468 -0
- hafnia/dataset/hafnia_dataset.py +453 -0
- hafnia/dataset/primitives/__init__.py +16 -0
- hafnia/dataset/primitives/bbox.py +137 -0
- hafnia/dataset/primitives/bitmask.py +182 -0
- hafnia/dataset/primitives/classification.py +56 -0
- hafnia/dataset/primitives/point.py +25 -0
- hafnia/dataset/primitives/polygon.py +100 -0
- hafnia/dataset/primitives/primitive.py +44 -0
- hafnia/dataset/primitives/segmentation.py +51 -0
- hafnia/dataset/primitives/utils.py +51 -0
- hafnia/dataset/table_transformations.py +183 -0
- hafnia/experiment/hafnia_logger.py +2 -2
- hafnia/helper_testing.py +63 -0
- hafnia/http.py +5 -3
- hafnia/platform/__init__.py +2 -2
- hafnia/platform/datasets.py +184 -0
- hafnia/platform/download.py +85 -23
- hafnia/torch_helpers.py +180 -95
- hafnia/utils.py +1 -1
- hafnia/visualizations/colors.py +267 -0
- hafnia/visualizations/image_visualizations.py +202 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/METADATA +212 -99
- hafnia-0.2.0.dist-info/RECORD +46 -0
- cli/data_cmds.py +0 -53
- hafnia-0.1.27.dist-info/RECORD +0 -27
- {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/WHEEL +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/entry_points.txt +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.0.dist-info}/licenses/LICENSE +0 -0
cli/__main__.py CHANGED

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import click
 
-from cli import consts,
+from cli import consts, dataset_cmds, experiment_cmds, profile_cmds, recipe_cmds, runc_cmds
 from cli.config import Config, ConfigSchema
 
 
@@ -46,7 +46,7 @@ def clear(cfg: Config) -> None:
 
 
 main.add_command(profile_cmds.profile)
-main.add_command(
+main.add_command(dataset_cmds.dataset)
 main.add_command(runc_cmds.runc)
 main.add_command(experiment_cmds.experiment)
 main.add_command(recipe_cmds.recipe)
cli/dataset_cmds.py ADDED

@@ -0,0 +1,60 @@
+from pathlib import Path
+from typing import Optional
+
+import click
+from rich import print as rprint
+
+import cli.consts as consts
+from cli.config import Config
+from hafnia import utils
+from hafnia.platform.datasets import create_rich_table_from_dataset
+
+
+@click.group()
+def dataset():
+    """Manage dataset interaction"""
+    pass
+
+
+@dataset.command("ls")
+@click.pass_obj
+def dataset_list(cfg: Config) -> None:
+    """List available datasets on Hafnia platform"""
+
+    from hafnia.platform.datasets import dataset_list
+
+    try:
+        datasets = dataset_list(cfg=cfg)
+    except Exception:
+        raise click.ClickException(consts.ERROR_GET_RESOURCE)
+
+    table = create_rich_table_from_dataset(datasets)
+    rprint(table)
+
+
+@dataset.command("download")
+@click.argument("dataset_name")
+@click.option(
+    "--destination",
+    "-d",
+    default=None,
+    required=False,
+    help=f"Destination folder to save the dataset. Defaults to '{utils.PATH_DATASETS}/<dataset_name>'",
+)
+@click.option("--force", "-f", is_flag=True, default=False, help="Flag to enable force redownload")
+@click.pass_obj
+def data_download(cfg: Config, dataset_name: str, destination: Optional[click.Path], force: bool) -> Path:
+    """Download dataset from Hafnia platform"""
+
+    from hafnia.platform import datasets
+
+    try:
+        path_dataset = datasets.download_or_get_dataset_path(
+            dataset_name=dataset_name,
+            cfg=cfg,
+            path_datasets_folder=destination,
+            force_redownload=force,
+        )
+    except Exception:
+        raise click.ClickException(consts.ERROR_GET_RESOURCE)
+    return path_dataset
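For orientation, the sketch below mirrors what the new `dataset download` command does programmatically. The dataset name is a placeholder, a configured Hafnia profile/API key is assumed, and `Config()` is assumed to pick up the active profile as it did in earlier releases:

```python
from cli.config import Config
from hafnia.platform import datasets

# Same call the `dataset download` CLI command wraps; "mnist" is a placeholder name.
path_dataset = datasets.download_or_get_dataset_path(
    dataset_name="mnist",
    cfg=Config(),
    path_datasets_folder=None,  # None falls back to the default datasets folder (utils.PATH_DATASETS)
    force_redownload=False,
)
print(path_dataset)
```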
cli/runc_cmds.py CHANGED

@@ -38,7 +38,7 @@ def runc():
 @click.pass_obj
 def launch_local(cfg: Config, exec_cmd: str, dataset: str, image_name: str) -> None:
     """Launch a job within the image."""
-    from hafnia.
+    from hafnia.platform.datasets import download_or_get_dataset_path
 
     is_local_dataset = "/" in dataset
     if is_local_dataset:
hafnia/data/__init__.py CHANGED

@@ -1,3 +1,3 @@
-from hafnia.data.factory import load_dataset
+from hafnia.data.factory import get_dataset_path, load_dataset
 
-__all__ = ["load_dataset"]
+__all__ = ["load_dataset", "get_dataset_path"]
hafnia/data/factory.py CHANGED

@@ -1,67 +1,20 @@
-import os
-import shutil
 from pathlib import Path
-from typing import Optional, Union
 
-from
+from hafnia.dataset.hafnia_dataset import HafniaDataset
+from hafnia.platform.datasets import download_or_get_dataset_path
 
-from cli.config import Config
-from hafnia import utils
-from hafnia.log import user_logger
-from hafnia.platform import download_resource, get_dataset_id
 
-
-def load_local(dataset_path: Path) -> Union[Dataset, DatasetDict]:
-    """Load a Hugging Face dataset from a local directory path."""
-    if not dataset_path.exists():
-        raise ValueError(f"Can not load dataset, directory does not exist -- {dataset_path}")
-    user_logger.info(f"Loading data from {dataset_path.as_posix()}")
-    return load_from_disk(dataset_path.as_posix())
-
-
-def download_or_get_dataset_path(
-    dataset_name: str,
-    cfg: Optional[Config] = None,
-    output_dir: Optional[str] = None,
-    force_redownload: bool = False,
-) -> Path:
-    """Download or get the path of the dataset."""
-
-    cfg = cfg or Config()
-    endpoint_dataset = cfg.get_platform_endpoint("datasets")
-    api_key = cfg.api_key
-
-    output_dir = output_dir or str(utils.PATH_DATASET)
-    dataset_path_base = Path(output_dir).absolute() / dataset_name
-    dataset_path_base.mkdir(exist_ok=True, parents=True)
-    dataset_path_sample = dataset_path_base / "sample"
-
-    if dataset_path_sample.exists() and not force_redownload:
-        user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
-        return dataset_path_sample
-
-    dataset_id = get_dataset_id(dataset_name, endpoint_dataset, api_key)
-    dataset_access_info_url = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
-
-    if force_redownload and dataset_path_sample.exists():
-        # Remove old files to avoid old files conflicting with new files
-        shutil.rmtree(dataset_path_sample, ignore_errors=True)
-    status = download_resource(dataset_access_info_url, str(dataset_path_base), api_key)
-    if status:
-        return dataset_path_sample
-    raise RuntimeError("Failed to download dataset")
-
-
-def load_dataset(dataset_name: str, force_redownload: bool = False) -> Union[Dataset, DatasetDict]:
+def load_dataset(dataset_name: str, force_redownload: bool = False) -> HafniaDataset:
     """Load a dataset either from a local path or from the Hafnia platform."""
 
-
-
-
+    path_dataset = get_dataset_path(dataset_name, force_redownload=force_redownload)
+    dataset = HafniaDataset.read_from_path(path_dataset)
+    return dataset
+
 
+def get_dataset_path(dataset_name: str, force_redownload: bool = False) -> Path:
     path_dataset = download_or_get_dataset_path(
         dataset_name=dataset_name,
         force_redownload=force_redownload,
     )
-
-    return dataset
+    return path_dataset
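A minimal sketch of the reworked loader API, assuming "mnist" stands in for a dataset name available on the platform:

```python
from hafnia.data import get_dataset_path, load_dataset

# Both helpers download the dataset on first use and reuse the local copy afterwards.
path_dataset = get_dataset_path("mnist")                  # local dataset folder
dataset = load_dataset("mnist", force_redownload=False)   # HafniaDataset read from that folder
```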
hafnia/dataset/dataset_helpers.py ADDED

@@ -0,0 +1,91 @@
+import io
+import math
+import random
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+import xxhash
+from PIL import Image
+
+
+def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
+    samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
+
+    split_name_column = []
+    for split_name, n_split_samples in samples_per_split.items():
+        split_name_column.extend([split_name] * n_split_samples)
+    random.Random(seed).shuffle(split_name_column)  # Shuffle the split names
+
+    return split_name_column
+
+
+def hash_file_xxhash(path: Path, chunk_size: int = 262144) -> str:
+    hasher = xxhash.xxh3_64()
+
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):  # 8192, 16384, 32768, 65536
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def hash_from_bytes(data: bytes) -> str:
+    hasher = xxhash.xxh3_64()
+    hasher.update(data)
+    return hasher.hexdigest()
+
+
+def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
+    pil_image = Image.fromarray(image)
+    buffer = io.BytesIO()
+    pil_image.save(buffer, format="PNG")
+    hash_value = hash_from_bytes(buffer.getvalue())
+    path_image = Path(path_folder) / f"{hash_value}.png"
+    pil_image.save(path_image)
+    return path_image
+
+
+def filename_as_hash_from_path(path_image: Path) -> str:
+    hash = hash_file_xxhash(path_image)
+    return f"{hash}{path_image.suffix}"
+
+
+def split_sizes_from_ratios(n_items: int, split_ratios: Dict[str, float]) -> Dict[str, int]:
+    summed_ratios = sum(split_ratios.values())
+    abs_tols = 0.0011  # Allow some tolerance for floating point errors {"test": 0.333, "val": 0.333, "train": 0.333}
+    if not math.isclose(summed_ratios, 1.0, abs_tol=abs_tols):  # Allow tolerance to allow e.g. (0.333, 0.333, 0.333)
+        raise ValueError(f"Split ratios must sum to 1.0. The summed values of {split_ratios} is {summed_ratios}")
+
+    # recaculate split sizes
+    split_ratios = {split_name: split_ratio / summed_ratios for split_name, split_ratio in split_ratios.items()}
+    split_sizes = {split_name: int(n_items * split_ratio) for split_name, split_ratio in split_ratios.items()}
+
+    remaining_items = n_items - sum(split_sizes.values())
+    if remaining_items > 0:  # Distribute remaining items evenly across splits
+        for _ in range(remaining_items):
+            # Select name by the largest error from the expected distribution
+            total_size = sum(split_sizes.values())
+            distribution_error = {
+                split_name: abs(split_ratios[split_name] - (size / total_size))
+                for split_name, size in split_sizes.items()
+            }
+
+            split_with_largest_error = sorted(distribution_error.items(), key=lambda x: x[1], reverse=True)[0][0]
+            split_sizes[split_with_largest_error] += 1
+
+    if sum(split_sizes.values()) != n_items:
+        raise ValueError("Something is wrong. The split sizes do not match the number of items.")
+
+    return split_sizes
+
+
+def select_evenly_across_list(lst: list, num_samples: int):
+    if num_samples >= len(lst):
+        return lst  # No need to sample
+    step = (len(lst) - 1) / (num_samples - 1)
+    indices = [int(round(step * i)) for i in range(num_samples)]  # noqa: RUF046
+    return [lst[index] for index in indices]
+
+
+def prefix_dict(d: dict, prefix: str) -> dict:
+    return {f"{prefix}.{k}": v for k, v in d.items()}
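A small illustration of the split helpers above; the ratios and item count are made up:

```python
from hafnia.dataset.dataset_helpers import create_split_name_list_from_ratios, split_sizes_from_ratios

ratios = {"train": 0.8, "validation": 0.1, "test": 0.1}

# Exact sizes per split for 10 items: {'train': 8, 'validation': 1, 'test': 1}
print(split_sizes_from_ratios(n_items=10, split_ratios=ratios))

# One split name per sample, shuffled deterministically with the given seed
split_names = create_split_name_list_from_ratios(split_ratios=ratios, n_items=10, seed=42)
print(split_names)
```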
hafnia/dataset/dataset_names.py ADDED

@@ -0,0 +1,71 @@
+from enum import Enum
+from typing import List
+
+FILENAME_DATASET_INFO = "dataset_info.json"
+FILENAME_ANNOTATIONS_JSONL = "annotations.jsonl"
+FILENAME_ANNOTATIONS_PARQUET = "annotations.parquet"
+
+DATASET_FILENAMES = [
+    FILENAME_DATASET_INFO,
+    FILENAME_ANNOTATIONS_JSONL,
+    FILENAME_ANNOTATIONS_PARQUET,
+]
+
+
+class DeploymentStage(Enum):
+    STAGING = "staging"
+    PRODUCTION = "production"
+
+
+class FieldName:
+    CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
+    CLASS_IDX: str = (
+        "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class
+    )
+    OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
+    CONFIDENCE: str = "confidence"  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
+
+    META: str = "meta"  # Contains metadata about each primitive, e.g. attributes color, occluded, iscrowd, etc.
+    TASK_NAME: str = "task_name"  # Name of the task this primitive is associated with, e.g. "bboxes" for Bbox
+
+    @staticmethod
+    def fields() -> List[str]:
+        """
+        Returns a list of expected field names for primitives.
+        """
+        return [
+            FieldName.CLASS_NAME,
+            FieldName.CLASS_IDX,
+            FieldName.OBJECT_ID,
+            FieldName.CONFIDENCE,
+            FieldName.META,
+            FieldName.TASK_NAME,
+        ]
+
+
+class ColumnName:
+    SAMPLE_INDEX: str = "sample_index"
+    FILE_NAME: str = "file_name"
+    HEIGHT: str = "height"
+    WIDTH: str = "width"
+    SPLIT: str = "split"
+    IS_SAMPLE: str = "is_sample"
+    REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+    META: str = "meta"
+
+
+class SplitName:
+    TRAIN = "train"
+    VAL = "validation"
+    TEST = "test"
+    UNDEFINED = "UNDEFINED"
+
+    @staticmethod
+    def valid_splits() -> List[str]:
+        return [SplitName.TRAIN, SplitName.VAL, SplitName.TEST]
+
+
+class DatasetVariant(Enum):
+    DUMP = "dump"
+    SAMPLE = "sample"
+    HIDDEN = "hidden"
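As an illustration (the small polars frame below is invented), the shared names can be used to build and query sample tables consistently:

```python
import polars as pl

from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName

print(SplitName.valid_splits())  # ['train', 'validation', 'test']
print(FieldName.fields())        # ['class_name', 'class_idx', 'object_id', 'confidence', 'meta', 'task_name']

# Illustrative samples table keyed by the shared column names
samples = pl.DataFrame(
    {
        ColumnName.FILE_NAME: ["a.png", "b.png", "c.png"],
        ColumnName.SPLIT: [SplitName.TRAIN, SplitName.TRAIN, SplitName.TEST],
    }
)
train_samples = samples.filter(pl.col(ColumnName.SPLIT) == SplitName.TRAIN)
```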
hafnia/dataset/dataset_transformation.py ADDED

@@ -0,0 +1,187 @@
+import hashlib
+import shutil
+from pathlib import Path
+from random import Random
+from typing import TYPE_CHECKING, Callable, Dict
+
+import cv2
+import numpy as np
+import polars as pl
+from PIL import Image
+from tqdm import tqdm
+
+from hafnia.dataset import dataset_helpers
+from hafnia.dataset.dataset_names import ColumnName
+from hafnia.log import user_logger
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+### Image transformations ###
+class AnonymizeByPixelation:
+    def __init__(self, resize_factor: float = 0.10):
+        self.resize_factor = resize_factor
+
+    def __call__(self, frame: np.ndarray) -> np.ndarray:
+        org_size = frame.shape[:2]
+        frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
+        frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
+        return frame
+
+
+def splits_by_ratios(dataset: "HafniaDataset", split_ratios: Dict[str, float], seed: int = 42) -> "HafniaDataset":
+    """
+    Divides the dataset into splits based on the provided ratios.
+
+    Example: Defining split ratios and applying the transformation
+
+    >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+    >>> split_ratios = {SplitName.TRAIN: 0.8, SplitName.VAL: 0.1, SplitName.TEST: 0.1}
+    >>> dataset_with_splits = splits_by_ratios(dataset, split_ratios, seed=42)
+    Or use the function as a
+    >>> dataset_with_splits = dataset.splits_by_ratios(split_ratios, seed=42)
+    """
+    n_items = len(dataset)
+    split_name_column = dataset_helpers.create_split_name_list_from_ratios(
+        split_ratios=split_ratios, n_items=n_items, seed=seed
+    )
+    table = dataset.samples.with_columns(pl.Series(split_name_column).alias("split"))
+    return dataset.update_table(table)
+
+
+def divide_split_into_multiple_splits(
+    dataset: "HafniaDataset",
+    divide_split_name: str,
+    split_ratios: Dict[str, float],
+) -> "HafniaDataset":
+    """
+    Divides a dataset split ('divide_split_name') into multiple splits based on the provided split
+    ratios ('split_ratios'). This is especially useful for some open datasets where they have only provide
+    two splits or only provide annotations for two splits. This function allows you to create additional
+    splits based on the provided ratios.
+
+    Example: Defining split ratios and applying the transformation
+    >>> dataset = HafniaDataset.read_from_path(Path("path/to/dataset"))
+    >>> divide_split_name = SplitName.TEST
+    >>> split_ratios = {SplitName.TEST: 0.8, SplitName.VAL: 0.2}
+    >>> dataset_with_splits = divide_split_into_multiple_splits(dataset, divide_split_name, split_ratios)
+    """
+    dataset_split_to_be_divided = dataset.create_split_dataset(split_name=divide_split_name)
+    if len(dataset_split_to_be_divided) == 0:
+        split_counts = dict(dataset.samples.select(pl.col(ColumnName.SPLIT).value_counts()).iter_rows())
+        raise ValueError(
+            f"No samples in the '{divide_split_name}' split to divide into multiple splits. {split_counts=}"
+        )
+    assert len(dataset_split_to_be_divided) > 0, f"No samples in the '{divide_split_name}' split!"
+    dataset_split_to_be_divided = dataset_split_to_be_divided.split_by_ratios(split_ratios=split_ratios, seed=42)
+
+    remaining_data = dataset.samples.filter(pl.col(ColumnName.SPLIT).is_in([divide_split_name]).not_())
+    new_table = pl.concat([remaining_data, dataset_split_to_be_divided.samples], how="vertical")
+    dataset_new = dataset.update_table(new_table)
+    return dataset_new
+
+
+def shuffle_dataset(dataset: "HafniaDataset", seed: int = 42) -> "HafniaDataset":
+    table = dataset.samples.sample(n=len(dataset), with_replacement=False, seed=seed, shuffle=True)
+    return dataset.update_table(table)
+
+
+def sample(dataset: "HafniaDataset", n_samples: int, shuffle: bool = True, seed: int = 42) -> "HafniaDataset":
+    table = dataset.samples.sample(n=n_samples, with_replacement=False, seed=seed, shuffle=shuffle)
+    return dataset.update_table(table)
+
+
+def define_sample_set_by_size(dataset: "HafniaDataset", n_samples: int, seed: int = 42) -> "HafniaDataset":
+    is_sample_indices = Random(seed).sample(range(len(dataset)), n_samples)
+    is_sample_column = [False for _ in range(len(dataset))]
+    for idx in is_sample_indices:
+        is_sample_column[idx] = True
+
+    table = dataset.samples.with_columns(pl.Series(is_sample_column).alias("is_sample"))
+    return dataset.update_table(table)
+
+
+def transform_images(
+    dataset: "HafniaDataset",
+    transform: Callable[[np.ndarray], np.ndarray],
+    path_output: Path,
+) -> "HafniaDataset":
+    new_paths = []
+    path_image_folder = path_output / "data"
+    path_image_folder.mkdir(parents=True, exist_ok=True)
+
+    for org_path in tqdm(dataset.samples["file_name"].to_list(), desc="Transform images"):
+        org_path = Path(org_path)
+        if not org_path.exists():
+            raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+
+        image = np.array(Image.open(org_path))
+        image_transformed = transform(image)
+        new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)
+
+        if not new_path.exists():
+            raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
+        new_paths.append(str(new_path))
+
+    table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
+    return dataset.update_table(table)
+
+
+def rename_to_unique_image_names(dataset: "HafniaDataset", path_output: Path) -> "HafniaDataset":
+    user_logger.info(f"Copy images to have unique filenames. New path is '{path_output}'")
+    shutil.rmtree(path_output, ignore_errors=True)  # Remove the output folder if it exists
+    new_paths = []
+    for org_path in tqdm(dataset.samples["file_name"].to_list(), desc="- Rename/copy images"):
+        org_path = Path(org_path)
+        if not org_path.exists():
+            raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
+
+        hash_name = hashlib.md5(str(org_path).encode()).hexdigest()[
+            :6
+        ]  # Generate a unique name based on the original file name
+        new_path = path_output / "data" / f"{hash_name}_{org_path.name}"
+        if not new_path.parent.exists():
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+
+        shutil.copyfile(org_path, new_path)  # Copy the original file to the new path
+        new_paths.append(str(new_path))
+
+    table = dataset.samples.with_columns(pl.Series(new_paths).alias("file_name"))
+    return dataset.update_table(table)
+
+
+### Hafnia Dataset Transformations ###
+class SplitsByRatios:
+    def __init__(self, split_ratios: dict, seed: int = 42):
+        self.split_ratios = split_ratios
+        self.seed = seed
+
+    def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+        return splits_by_ratios(dataset, self.split_ratios, self.seed)
+
+
+class ShuffleDataset:
+    def __init__(self, seed: int = 42):
+        self.seed = seed
+
+    def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+        return shuffle_dataset(dataset, self.seed)
+
+
+class SampleSetBySize:
+    def __init__(self, n_samples: int, seed: int = 42):
+        self.n_samples = n_samples
+        self.seed = seed
+
+    def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+        return define_sample_set_by_size(dataset, self.n_samples, self.seed)
+
+
+class TransformImages:
+    def __init__(self, transform: Callable[[np.ndarray], np.ndarray], path_output: Path):
+        self.transform = transform
+        self.path_output = path_output
+
+    def __call__(self, dataset: "HafniaDataset") -> "HafniaDataset":
+        return transform_images(dataset, self.transform, self.path_output)