hafnia 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/config.py +17 -4
- hafnia/data/factory.py +13 -10
- hafnia/dataset/dataset_names.py +2 -1
- hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
- hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
- hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
- hafnia/dataset/hafnia_dataset.py +202 -31
- hafnia/dataset/operations/dataset_stats.py +15 -0
- hafnia/dataset/operations/dataset_transformations.py +82 -0
- hafnia/dataset/{table_transformations.py → operations/table_transformations.py} +1 -1
- hafnia/experiment/hafnia_logger.py +5 -5
- hafnia/helper_testing.py +48 -3
- hafnia/platform/datasets.py +26 -13
- hafnia/utils.py +20 -1
- hafnia/visualizations/image_visualizations.py +1 -1
- {hafnia-0.2.0.dist-info → hafnia-0.2.1.dist-info}/METADATA +17 -20
- {hafnia-0.2.0.dist-info → hafnia-0.2.1.dist-info}/RECORD +20 -16
- hafnia/dataset/dataset_transformation.py +0 -187
- {hafnia-0.2.0.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
- {hafnia-0.2.0.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.2.0.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0
cli/config.py
CHANGED

@@ -80,7 +80,7 @@ class Config:
     def __init__(self, config_path: Optional[Path] = None) -> None:
         self.config_path = self.resolve_config_path(config_path)
         self.config_path.parent.mkdir(parents=True, exist_ok=True)
-        self.config_data =
+        self.config_data = Config.load_config(self.config_path)
 
     def resolve_config_path(self, path: Optional[Path] = None) -> Path:
         if path:
@@ -111,12 +111,25 @@ class Config:
         endpoint = self.config.platform_url + PLATFORM_API_MAPPING[method]
         return endpoint
 
-
+    @staticmethod
+    def load_config(config_path: Path) -> ConfigFileSchema:
         """Load configuration from file."""
-
+
+        # Environment variables has higher priority than config file
+        HAFNIA_API_KEY = os.getenv("HAFNIA_API_KEY")
+        HAFNIA_PLATFORM_URL = os.getenv("HAFNIA_PLATFORM_URL")
+        if HAFNIA_API_KEY and HAFNIA_PLATFORM_URL:
+            HAFNIA_PROFILE_NAME = os.getenv("HAFNIA_PROFILE_NAME", "default").strip()
+            cfg = ConfigFileSchema(
+                active_profile=HAFNIA_PROFILE_NAME,
+                profiles={HAFNIA_PROFILE_NAME: ConfigSchema(platform_url=HAFNIA_PLATFORM_URL, api_key=HAFNIA_API_KEY)},
+            )
+            return cfg
+
+        if not config_path.exists():
             return ConfigFileSchema()
         try:
-            with open(
+            with open(config_path.as_posix(), "r") as f:
                 data = json.load(f)
             return ConfigFileSchema(**data)
         except json.JSONDecodeError:
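The new `load_config` gives environment variables priority over the JSON config file. Below is a minimal usage sketch, assuming the `cli.config` import path implied by the wheel layout; the key, URL, and profile values are placeholders:

```python
import os

# Placeholders: with both variables set, load_config() builds the active
# profile from the environment and never reads the JSON config file.
os.environ["HAFNIA_API_KEY"] = "example-api-key"
os.environ["HAFNIA_PLATFORM_URL"] = "https://platform.example.com"
os.environ["HAFNIA_PROFILE_NAME"] = "ci"  # optional; falls back to "default"

from cli.config import Config  # import path assumed from the file layout

config = Config()  # __init__ now populates config_data via Config.load_config()
```

Note that the override only kicks in when both `HAFNIA_API_KEY` and `HAFNIA_PLATFORM_URL` are set; otherwise the file-based path is used as before.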
hafnia/data/factory.py
CHANGED

@@ -1,20 +1,23 @@
+import os
 from pathlib import Path
+from typing import Any
 
-from hafnia
-from hafnia.
+from hafnia import utils
+from hafnia.dataset.hafnia_dataset import HafniaDataset, get_or_create_dataset_path_from_recipe
 
 
-def load_dataset(
+def load_dataset(recipe: Any, force_redownload: bool = False) -> HafniaDataset:
     """Load a dataset either from a local path or from the Hafnia platform."""
 
-    path_dataset = get_dataset_path(
-    dataset = HafniaDataset.
+    path_dataset = get_dataset_path(recipe, force_redownload=force_redownload)
+    dataset = HafniaDataset.from_path(path_dataset)
     return dataset
 
 
-def get_dataset_path(
-
-
-
-)
+def get_dataset_path(recipe: Any, force_redownload: bool = False) -> Path:
+    if utils.is_hafnia_cloud_job():
+        return Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
+
+    path_dataset = get_or_create_dataset_path_from_recipe(recipe, force_redownload=force_redownload)
+
     return path_dataset
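`load_dataset` now accepts a recipe in implicit form (see `dataset_recipe.py` below) rather than only a dataset name. A hedged sketch; the `"mnist"` name is taken from the docstring examples in this release and assumes the dataset is available on the platform:

```python
from pathlib import Path

from hafnia.data.factory import load_dataset

dataset = load_dataset("mnist")               # str -> resolved by dataset name
dataset = load_dataset(Path("./my_dataset"))  # Path -> loaded from a local folder (assumed to exist)

# On HAFNIA cloud jobs, get_dataset_path() short-circuits to MDI_DATASET_DIR
# (default /opt/ml/input/data/training) and no download is attempted.
```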
hafnia/dataset/dataset_names.py
CHANGED

@@ -1,11 +1,12 @@
 from enum import Enum
 from typing import List
 
+FILENAME_RECIPE_JSON = "recipe.json"
 FILENAME_DATASET_INFO = "dataset_info.json"
 FILENAME_ANNOTATIONS_JSONL = "annotations.jsonl"
 FILENAME_ANNOTATIONS_PARQUET = "annotations.parquet"
 
-
+DATASET_FILENAMES_REQUIRED = [
     FILENAME_DATASET_INFO,
     FILENAME_ANNOTATIONS_JSONL,
     FILENAME_ANNOTATIONS_PARQUET,
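The new `DATASET_FILENAMES_REQUIRED` list makes it straightforward to validate a dataset folder up front. A small hypothetical helper (`missing_dataset_files` is not part of the package):

```python
from pathlib import Path
from typing import List

from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED


def missing_dataset_files(path_dataset: Path) -> List[str]:
    """Return the required dataset files that are absent from a folder."""
    return [name for name in DATASET_FILENAMES_REQUIRED if not (path_dataset / name).exists()]
```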
hafnia/dataset/dataset_recipe/dataset_recipe.py
ADDED

@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+from pydantic import (
+    field_serializer,
+    field_validator,
+)
+
+from hafnia import utils
+from hafnia.dataset.dataset_recipe import recipe_transforms
+from hafnia.dataset.dataset_recipe.recipe_types import RecipeCreation, RecipeTransform, Serializable
+from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+class DatasetRecipe(Serializable):
+    creation: RecipeCreation
+    operations: Optional[List[RecipeTransform]] = None
+
+    def build(self) -> HafniaDataset:
+        dataset = self.creation.build()
+        if self.operations:
+            for operation in self.operations:
+                dataset = operation.build(dataset)
+        return dataset
+
+    def append_operation(self, operation: RecipeTransform) -> DatasetRecipe:
+        """Append an operation to the dataset recipe."""
+        if self.operations is None:
+            self.operations = []
+        self.operations.append(operation)
+        return self
+
+    ### Creation Methods (using the 'from_X' )###
+    @staticmethod
+    def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> DatasetRecipe:
+        creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
+        return DatasetRecipe(creation=creation)
+
+    @staticmethod
+    def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
+        creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
+        return DatasetRecipe(creation=creation)
+
+    @staticmethod
+    def from_merge(recipe0: DatasetRecipe, recipe1: DatasetRecipe) -> DatasetRecipe:
+        return DatasetRecipe(creation=FromMerge(recipe0=recipe0, recipe1=recipe1))
+
+    @staticmethod
+    def from_merger(recipes: List[DatasetRecipe]) -> DatasetRecipe:
+        """Create a DatasetRecipe from a list of DatasetRecipes."""
+        if not recipes:
+            raise ValueError("The list of recipes cannot be empty.")
+        if len(recipes) == 1:
+            return recipes[0]
+        creation = FromMerger(recipes=recipes)
+        return DatasetRecipe(creation=creation)
+
+    @staticmethod
+    def from_json_str(json_str: str) -> "DatasetRecipe":
+        """Deserialize from a JSON string."""
+        data = json.loads(json_str)
+        dataset_recipe = DatasetRecipe.from_dict(data)
+        if not isinstance(dataset_recipe, DatasetRecipe):
+            raise TypeError(f"Expected DatasetRecipe, got {type(dataset_recipe).__name__}.")
+        return dataset_recipe
+
+    @staticmethod
+    def from_json_file(path_json: Path) -> "DatasetRecipe":
+        json_str = path_json.read_text(encoding="utf-8")
+        return DatasetRecipe.from_json_str(json_str)
+
+    @staticmethod
+    def from_implicit_form(recipe: Any) -> DatasetRecipe:
+        """
+        Recursively convert from implicit recipe to explicit form.
+        Handles mixed implicit/explicit recipes.
+
+        Conversion rules:
+        - str: Will get a dataset by name -> DatasetRecipeFromName
+        - Path: Will get a dataset from path -> DatasetRecipeFromPath
+        - tuple: Will merge datasets specified in the tuple -> RecipeMerger
+        - list: Will define a list of transformations -> RecipeTransforms
+
+        Example: DataRecipe from dataset name:
+        ```python
+        recipe_implicit = "mnist"
+        recipe_explicit = DatasetRecipe.from_implicit_form(recipe_implicit)
+        >>> recipe_explicit
+        DatasetRecipeFromName(dataset_name='mnist', force_redownload=False)
+        ```
+
+        Example: DataRecipe from tuple (merging multiple recipes):
+        ```python
+        recipe_implicit = ("dataset1", "dataset2")
+        recipe_explicit = DatasetRecipe.from_implicit_form(recipe_implicit)
+        >>> recipe_explicit
+        RecipeMerger(
+            recipes=[
+                DatasetRecipeFromName(dataset_name='dataset1', force_redownload=False),
+                DatasetRecipeFromName(dataset_name='dataset2', force_redownload=False)
+            ]
+        )
+
+        Example: DataRecipe from list (recipe and transformations):
+        ```python
+        recipe_implicit = ["mnist", SelectSamples(n_samples=20), Shuffle(seed=123)]
+        recipe_explicit = DatasetRecipe.from_implicit_form(recipe_implicit)
+        >>> recipe_explicit
+        Transforms(
+            recipe=DatasetRecipeFromName(dataset_name='mnist', force_redownload=False),
+            transforms=[SelectSamples(n_samples=20), Shuffle(seed=123)]
+        )
+        ```
+
+        """
+        if isinstance(recipe, DatasetRecipe):  # type: ignore
+            # It is possible to do an early return if recipe is a 'DataRecipe'-type even for nested and
+            # potentially mixed recipes. If you (really) think about it, this might surprise you,
+            # as this will bypass the conversion logic for nested recipes.
+            # However, this is not a problem as 'DataRecipe' classes are also pydantic models,
+            # so if a user introduces a 'DataRecipe'-class in the recipe (in potentially
+            # some nested and mixed implicit/explicit form) it will (due to pydantic validation) force
+            # the user to specify all nested recipes to be converted to explicit form.
+            return recipe
+
+        if isinstance(recipe, str):  # str-type is convert to DatasetFromName
+            return DatasetRecipe.from_name(name=recipe)
+
+        if isinstance(recipe, Path):  # Path-type is convert to DatasetFromPath
+            return DatasetRecipe.from_path(path_folder=recipe)
+
+        if isinstance(recipe, tuple):  # tuple-type is convert to DatasetMerger
+            recipes = [DatasetRecipe.from_implicit_form(item) for item in recipe]
+            return DatasetRecipe.from_merger(recipes=recipes)
+
+        if isinstance(recipe, list):  # list-type is convert to Transforms
+            if len(recipe) == 0:
+                raise ValueError("List of recipes cannot be empty")
+
+            dataset_recipe = recipe[0]  # First element is the dataset recipe
+            loader = DatasetRecipe.from_implicit_form(dataset_recipe)
+
+            transforms = recipe[1:]  # Remaining items are transformations
+            return DatasetRecipe(creation=loader.creation, operations=transforms)
+
+        raise ValueError(f"Unsupported recipe type: {type(recipe)}")
+
+    ### Dataset Recipe Transformations ###
+    def shuffle(recipe: DatasetRecipe, seed: int = 42) -> DatasetRecipe:
+        operation = recipe_transforms.Shuffle(seed=seed)
+        recipe.append_operation(operation)
+        return recipe
+
+    def select_samples(
+        recipe: DatasetRecipe, n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+    ) -> DatasetRecipe:
+        operation = recipe_transforms.SelectSamples(
+            n_samples=n_samples, shuffle=shuffle, seed=seed, with_replacement=with_replacement
+        )
+        recipe.append_operation(operation)
+        return recipe
+
+    def splits_by_ratios(recipe: DatasetRecipe, split_ratios: Dict[str, float], seed: int = 42) -> DatasetRecipe:
+        operation = recipe_transforms.SplitsByRatios(split_ratios=split_ratios, seed=seed)
+        recipe.append_operation(operation)
+        return recipe
+
+    def split_into_multiple_splits(
+        recipe: DatasetRecipe, split_name: str, split_ratios: Dict[str, float]
+    ) -> DatasetRecipe:
+        operation = recipe_transforms.SplitIntoMultipleSplits(split_name=split_name, split_ratios=split_ratios)
+        recipe.append_operation(operation)
+        return recipe
+
+    def define_sample_set_by_size(recipe: DatasetRecipe, n_samples: int, seed: int = 42) -> DatasetRecipe:
+        operation = recipe_transforms.DefineSampleSetBySize(n_samples=n_samples, seed=seed)
+        recipe.append_operation(operation)
+        return recipe
+
+    ### Conversions ###
+    def as_python_code(self, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
+        str_operations = [self.creation.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)]
+        if self.operations:
+            for op in self.operations:
+                str_operations.append(op.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs))
+        operations_str = ".".join(str_operations)
+        return operations_str
+
+    def as_short_name(self) -> str:
+        """Return a short name for the transforms."""
+
+        creation_name = self.creation.as_short_name()
+        if self.operations is None or len(self.operations) == 0:
+            return creation_name
+        short_names = [creation_name]
+        for operation in self.operations:
+            short_names.append(operation.as_short_name())
+        transforms_str = ",".join(short_names)
+        return f"Recipe({transforms_str})"
+
+    def as_json_str(self, indent: int = 2) -> str:
+        """Serialize the dataset recipe to a JSON string."""
+        data = self.model_dump(mode="json")
+        # data = type_as_first_key(data)
+        return json.dumps(data, indent=indent, ensure_ascii=False)
+
+    def as_json_file(self, path_json: Path, indent: int = 2) -> None:
+        """Serialize the dataset recipe to a JSON file."""
+        json_str = self.as_json_str(indent=indent)
+        path_json.write_text(json_str, encoding="utf-8")
+
+    ### Validation and Serialization ###
+    @field_validator("creation", mode="plain")
+    @classmethod
+    def validate_creation(cls, creation: Union[Dict, RecipeCreation]) -> RecipeCreation:
+        if isinstance(creation, dict):
+            creation = Serializable.from_dict(creation)  # type: ignore[assignment]
+        if not isinstance(creation, RecipeCreation):
+            raise TypeError(f"Operation must be an instance of RecipeCreation, got {type(creation).__name__}.")
+        return creation
+
+    @field_serializer("creation")
+    def serialize_creation(self, creation: RecipeCreation) -> dict:
+        return creation.model_dump()
+
+    @field_validator("operations", mode="plain")
+    @classmethod
+    def validate_operation(cls, operations: List[Union[Dict, RecipeTransform]]) -> List[RecipeTransform]:
+        if operations is None:
+            return None
+        validated_operations = []
+        for operation in operations:
+            if isinstance(operation, dict):
+                operation = Serializable.from_dict(operation)  # type: ignore[assignment]
+            if not isinstance(operation, RecipeTransform):
+                raise TypeError(f"Operation must be an instance of RecipeTransform, got {type(operation).__name__}.")
+            validated_operations.append(operation)
+        return validated_operations
+
+    @field_serializer("operations")
+    def serialize_operations(self, operations: Optional[List[RecipeTransform]]) -> Optional[List[dict]]:
+        """Serialize the operations to a list of dictionaries."""
+        if operations is None:
+            return None
+        return [operation.model_dump() for operation in operations]
+
+
+def unique_name_from_recipe(recipe: DatasetRecipe) -> str:
+    if isinstance(recipe.creation, FromName) and recipe.operations is None:
+        # If the dataset recipe is simply a DatasetFromName, we bypass the hashing logic
+        # and return the name directly. The dataset is already uniquely identified by its name.
+        # Add version if need... Optionally, you may also completely delete this exception
+        # and always return the unique name including the hash to support versioning.
+        return recipe.creation.name  # Dataset name e.g 'mnist'
+    recipe_json_str = recipe.model_dump_json()
+    hash_recipe = utils.hash_from_string(recipe_json_str)
+    short_recipe_str = recipe.as_short_name()
+    unique_name = f"{short_recipe_str}_{hash_recipe}"
+    return unique_name
+
+
+def get_dataset_path_from_recipe(recipe: DatasetRecipe, path_datasets: Optional[Union[Path, str]] = None) -> Path:
+    path_datasets = path_datasets or utils.PATH_DATASETS
+    path_datasets = Path(path_datasets)
+    unique_dataset_name = unique_name_from_recipe(recipe)
+    return path_datasets / unique_dataset_name
+
+
+class FromPath(RecipeCreation):
+    path_folder: Path
+    check_for_images: bool = True
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_path
+
+    def as_short_name(self) -> str:
+        return f"'{self.path_folder}'".replace(os.sep, "|")
+
+
+class FromName(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    download_files: bool = True
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name
+
+    def as_short_name(self) -> str:
+        return self.name
+
+
+class FromMerge(RecipeCreation):
+    recipe0: DatasetRecipe
+    recipe1: DatasetRecipe
+
+    @staticmethod
+    def get_function():
+        return HafniaDataset.merge
+
+    def as_short_name(self) -> str:
+        merger = FromMerger(recipes=[self.recipe0, self.recipe1])
+        return merger.as_short_name()
+
+
+class FromMerger(RecipeCreation):
+    recipes: List[DatasetRecipe]
+
+    def build(self) -> HafniaDataset:
+        """Build the dataset from the merged recipes."""
+        datasets = [recipe.build() for recipe in self.recipes]
+        return self.get_function()(datasets=datasets)
+
+    @staticmethod
+    def get_function():
+        return HafniaDataset.from_merger
+
+    def as_short_name(self) -> str:
+        return f"Merger({','.join(recipe.as_short_name() for recipe in self.recipes)})"
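The recipe API above supports both an implicit shorthand and an explicit, serializable form. A sketch of the round trip, assuming the `mnist` dataset from the docstring examples is available:

```python
from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe
from hafnia.dataset.dataset_recipe.recipe_transforms import SelectSamples, Shuffle

# Implicit form: the first list element creates the dataset, the rest transform it.
recipe = DatasetRecipe.from_implicit_form(["mnist", SelectSamples(n_samples=20), Shuffle(seed=123)])

# Recipes serialize to JSON; the __type__ discriminator added by Serializable
# lets from_json_str() reconstruct the right creation/transform classes.
restored = DatasetRecipe.from_json_str(recipe.as_json_str())

# build() materializes the dataset: creation first, then each operation in order.
dataset = restored.build()
```

Note that `unique_name_from_recipe` returns the bare dataset name for a plain `FromName` recipe with no operations, so simple name-based recipes keep stable cache paths under `utils.PATH_DATASETS`.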
hafnia/dataset/dataset_recipe/recipe_transforms.py
ADDED

@@ -0,0 +1,53 @@
+from typing import TYPE_CHECKING, Callable, Dict
+
+from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
+from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+if TYPE_CHECKING:
+    pass
+
+
+class Shuffle(RecipeTransform):
+    seed: int = 42
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.shuffle
+
+
+class SelectSamples(RecipeTransform):
+    n_samples: int
+    shuffle: bool = True
+    seed: int = 42
+    with_replacement: bool = False
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.select_samples
+
+
+class SplitsByRatios(RecipeTransform):
+    split_ratios: Dict[str, float]
+    seed: int = 42
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.splits_by_ratios
+
+
+class SplitIntoMultipleSplits(RecipeTransform):
+    split_name: str
+    split_ratios: Dict[str, float]
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.split_into_multiple_splits
+
+
+class DefineSampleSetBySize(RecipeTransform):
+    n_samples: int
+    seed: int = 42
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.define_sample_set_by_size
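Each transform is a pydantic model whose fields become keyword arguments for the matching `HafniaDataset` method. A sketch, assuming `dataset` is an existing `HafniaDataset` instance:

```python
from hafnia.dataset.dataset_recipe.recipe_transforms import SelectSamples

op = SelectSamples(n_samples=100, seed=0)

# RecipeTransform.build() (see recipe_types.py below) forwards the fields, so
# this is equivalent to:
#   HafniaDataset.select_samples(dataset=dataset, n_samples=100, shuffle=True,
#                                seed=0, with_replacement=False)
subset = op.build(dataset)
```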
hafnia/dataset/dataset_recipe/recipe_types.py
ADDED

@@ -0,0 +1,140 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from pydantic import BaseModel, computed_field
+
+from hafnia import utils
+
+if TYPE_CHECKING:
+    from hafnia.dataset.hafnia_dataset import HafniaDataset
+
+
+class Serializable(BaseModel, ABC):
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def __type__(self) -> str:
+        return self.__class__.__name__
+
+    @classmethod
+    def get_nested_subclasses(cls) -> List[type["Serializable"]]:
+        """Recursively get all subclasses of a class."""
+        all_subclasses = []
+        for subclass in cls.__subclasses__():
+            all_subclasses.append(subclass)
+            all_subclasses.extend(subclass.get_nested_subclasses())
+        return all_subclasses
+
+    @classmethod
+    def name_to_type_mapping(cls) -> Dict[str, type["Serializable"]]:
+        """Create a mapping from class names to class types."""
+        return {subclass.__name__: subclass for subclass in cls.get_nested_subclasses()}
+
+    @staticmethod
+    def from_dict(data: Dict) -> "Serializable":
+        dataset_spec_args = data.copy()
+        dataset_type_name = dataset_spec_args.pop("__type__", None)
+        name_to_type_mapping = Serializable.name_to_type_mapping()
+        SerializableClass = name_to_type_mapping[dataset_type_name]
+        return SerializableClass(**dataset_spec_args)
+
+    def get_kwargs(self, keep_default_fields: bool) -> Dict:
+        """Return a dictionary of fields that are not set to their default values."""
+        kwargs = dict(self)
+        kwargs.pop("__type__", None)
+
+        if keep_default_fields:
+            return kwargs
+
+        kwargs_no_defaults = {}
+        for key, value in kwargs.items():
+            default_value = self.model_fields[key].get_default()
+            if value != default_value:
+                kwargs_no_defaults[key] = value
+
+        return kwargs_no_defaults
+
+    @abstractmethod
+    def as_short_name(self) -> str:
+        pass
+
+    def as_python_code(self, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
+        """Generate code representation of the operation."""
+        kwargs = self.get_kwargs(keep_default_fields=keep_default_fields)
+
+        args_as_strs = []
+        for argument_name, argument_value in kwargs.items():
+            # In case an argument is a Serializable, we want to keep its default fields
+            str_value = recursive_as_code(argument_value, keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)
+            if as_kwargs:
+                args_as_strs.append(f"{argument_name}={str_value}")
+            else:
+                args_as_strs.append(str_value)
+
+        args_as_str = ", ".join(args_as_strs)
+        class_name = self.__class__.__name__
+        function_name = utils.pascal_to_snake_case(class_name)
+        return f"{function_name}({args_as_str})"
+
+
+def recursive_as_code(value: Any, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
+    if isinstance(value, Serializable):
+        return value.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)
+
+    elif isinstance(value, list):
+        as_strs = []
+        for item in value:
+            str_item = recursive_as_code(item, keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)
+            as_strs.append(str_item)
+        as_str = ", ".join(as_strs)
+        return f"[{as_str}]"
+
+    elif isinstance(value, dict):
+        as_strs = []
+        for key, item in value.items():
+            str_item = recursive_as_code(item, keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)
+            as_strs.append(f"{key!r}: {str_item}")
+        as_str = ", ".join(as_strs)
+        return "{" + as_str + "}"
+
+    return f"{value!r}"
+
+
+class RecipeCreation(Serializable):
+    @staticmethod
+    @abstractmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        pass
+
+    def build(self) -> "HafniaDataset":
+        from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe
+
+        kwargs = dict(self)
+        kwargs_recipes_as_datasets = {}
+        for key, value in kwargs.items():
+            if isinstance(value, DatasetRecipe):
+                value = value.build()
+                key = key.replace("recipe", "dataset")
+            kwargs_recipes_as_datasets[key] = value
+        return self.get_function()(**kwargs_recipes_as_datasets)
+
+    def as_python_code(self, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
+        """Generate code representation of the operation."""
+        as_python_code = Serializable.as_python_code(self, keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)
+        return f"DatasetRecipe.{as_python_code}"
+
+
+class RecipeTransform(Serializable):
+    @staticmethod
+    @abstractmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        pass
+
+    def build(self, dataset: "HafniaDataset") -> "HafniaDataset":
+        kwargs = dict(self)
+        return self.get_function()(dataset=dataset, **kwargs)
+
+    def as_short_name(self) -> str:
+        return self.__class__.__name__
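`Serializable` tags every dump with a `__type__` discriminator and `from_dict` uses the subclass registry to rebuild the right class. A minimal sketch:

```python
from hafnia.dataset.dataset_recipe.recipe_transforms import Shuffle
from hafnia.dataset.dataset_recipe.recipe_types import Serializable

data = Shuffle(seed=7).model_dump()
# -> {'seed': 7, '__type__': 'Shuffle'}  (computed __type__ is included in dumps)

op = Serializable.from_dict(data)
assert isinstance(op, Shuffle) and op.seed == 7
```

Since the registry is built from `__subclasses__()`, a subclass must be imported before `from_dict` can resolve its name.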