hafnia 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. cli/__main__.py +16 -3
  2. cli/config.py +45 -4
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/keychain.py +88 -0
  8. cli/profile_cmds.py +10 -6
  9. cli/runc_cmds.py +5 -5
  10. cli/trainer_package_cmds.py +65 -0
  11. hafnia/__init__.py +2 -0
  12. hafnia/data/factory.py +1 -2
  13. hafnia/dataset/dataset_helpers.py +9 -14
  14. hafnia/dataset/dataset_names.py +10 -5
  15. hafnia/dataset/dataset_recipe/dataset_recipe.py +165 -67
  16. hafnia/dataset/dataset_recipe/recipe_transforms.py +48 -4
  17. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  18. hafnia/dataset/dataset_upload_helper.py +265 -56
  19. hafnia/dataset/format_conversions/image_classification_from_directory.py +106 -0
  20. hafnia/dataset/format_conversions/torchvision_datasets.py +281 -0
  21. hafnia/dataset/hafnia_dataset.py +577 -213
  22. hafnia/dataset/license_types.py +63 -0
  23. hafnia/dataset/operations/dataset_stats.py +259 -3
  24. hafnia/dataset/operations/dataset_transformations.py +332 -7
  25. hafnia/dataset/operations/table_transformations.py +43 -5
  26. hafnia/dataset/primitives/__init__.py +8 -0
  27. hafnia/dataset/primitives/bbox.py +25 -12
  28. hafnia/dataset/primitives/bitmask.py +26 -14
  29. hafnia/dataset/primitives/classification.py +16 -8
  30. hafnia/dataset/primitives/point.py +7 -3
  31. hafnia/dataset/primitives/polygon.py +16 -9
  32. hafnia/dataset/primitives/segmentation.py +10 -7
  33. hafnia/experiment/hafnia_logger.py +111 -8
  34. hafnia/http.py +16 -2
  35. hafnia/platform/__init__.py +9 -3
  36. hafnia/platform/builder.py +12 -10
  37. hafnia/platform/dataset_recipe.py +104 -0
  38. hafnia/platform/datasets.py +47 -9
  39. hafnia/platform/download.py +25 -19
  40. hafnia/platform/experiment.py +51 -56
  41. hafnia/platform/trainer_package.py +57 -0
  42. hafnia/utils.py +81 -13
  43. hafnia/visualizations/image_visualizations.py +4 -4
  44. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/METADATA +40 -34
  45. hafnia-0.4.0.dist-info/RECORD +56 -0
  46. cli/recipe_cmds.py +0 -45
  47. hafnia-0.2.4.dist-info/RECORD +0 -49
  48. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/WHEEL +0 -0
  49. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/entry_points.txt +0 -0
  50. {hafnia-0.2.4.dist-info → hafnia-0.4.0.dist-info}/licenses/LICENSE +0 -0
cli/profile_cmds.py CHANGED
@@ -14,7 +14,7 @@ def profile():
 
 @profile.command("ls")
 @click.pass_obj
-def profile_ls(cfg: Config) -> None:
+def cmd_profile_ls(cfg: Config) -> None:
     """List all available profiles."""
     profiles = cfg.available_profiles
     if not profiles:
@@ -31,7 +31,7 @@ def profile_ls(cfg: Config) -> None:
 @profile.command("use")
 @click.argument("profile_name", required=True)
 @click.pass_obj
-def profile_use(cfg: Config, profile_name: str) -> None:
+def cmd_profile_use(cfg: Config, profile_name: str) -> None:
     """Switch to a different profile."""
     if len(cfg.available_profiles) == 0:
         raise click.ClickException(consts.ERROR_CONFIGURE)
@@ -50,10 +50,13 @@ def profile_use(cfg: Config, profile_name: str) -> None:
 @click.option(
     "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
 )
+@click.option(
+    "--use-keychain", is_flag=True, help="Store API key in system keychain instead of config file", default=False
+)
 @click.pass_obj
-def profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
+def cmd_profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool, use_keychain: bool) -> None:
     """Create a new profile."""
-    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
+    cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key, use_keychain=use_keychain)
 
     cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
     profile_show(cfg)
@@ -62,7 +65,7 @@ def profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate:
 @profile.command("rm")
 @click.argument("profile_name", required=True)
 @click.pass_obj
-def profile_rm(cfg: Config, profile_name: str) -> None:
+def cmd_profile_rm(cfg: Config, profile_name: str) -> None:
     """Remove a profile."""
     if len(cfg.available_profiles) == 0:
         raise click.ClickException(consts.ERROR_CONFIGURE)
@@ -80,7 +83,8 @@ def profile_rm(cfg: Config, profile_name: str) -> None:
 
 @profile.command("active")
 @click.pass_obj
-def profile_active(cfg: Config) -> None:
+def cmd_profile_active(cfg: Config) -> None:
+    """Show the currently active profile."""
     try:
         profile_show(cfg)
     except Exception as e:
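
Note: the new `--use-keychain` flag is threaded into `ConfigSchema`, with the actual keychain storage presumably provided by the new `cli/keychain.py` module (+88 lines). A minimal sketch of what the command now does, assuming `Config` and `ConfigSchema` are importable from `cli.config` as elsewhere in this diff; URL and key values are placeholders:

    # Roughly what `hafnia profile create --use-keychain` does per the hunk above (sketch only).
    from cli.config import Config, ConfigSchema

    cfg = Config()
    profile = ConfigSchema(
        platform_url="https://api.example.com",  # placeholder URL
        api_key="hfn_xxx",                       # placeholder key
        use_keychain=True,                       # keep the key out of the config file
    )
    cfg.add_profile(profile_name="work", profile=profile, set_active=True)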
cli/runc_cmds.py CHANGED
@@ -13,7 +13,7 @@ from hafnia.log import sys_logger, user_logger
 
 @click.group(name="runc")
 def runc():
-    """Experiment management commands"""
+    """Creating and running trainer packages locally"""
     pass
 
 
@@ -90,10 +90,10 @@ def launch_local(cfg: Config, exec_cmd: str, dataset: str, image_name: str) -> N
 @click.pass_obj
 def build(cfg: Config, recipe_url: str, state_file: str, repo: str) -> None:
     """Build docker image with a given recipe."""
-    from hafnia.platform.builder import build_image, prepare_recipe
+    from hafnia.platform.builder import build_image, prepare_trainer_package
 
     with TemporaryDirectory() as temp_dir:
-        metadata = prepare_recipe(recipe_url, Path(temp_dir), cfg.api_key)
+        metadata = prepare_trainer_package(recipe_url, Path(temp_dir), cfg.api_key)
         build_image(metadata, repo, state_file=state_file)
 
 
@@ -109,7 +109,7 @@ def build_local(recipe: Path, state_file: str, repo: str) -> None:
     import seedir
 
     from hafnia.platform.builder import build_image
-    from hafnia.utils import filter_recipe_files
+    from hafnia.utils import filter_trainer_package_files
 
     recipe = Path(recipe)
 
@@ -123,7 +123,7 @@ def build_local(recipe: Path, state_file: str, repo: str) -> None:
         with zipfile.ZipFile(recipe.as_posix(), "r") as zip_ref:
             zip_ref.extractall(recipe_dir)
     elif recipe.is_dir():
-        for rf in filter_recipe_files(recipe):
+        for rf in filter_trainer_package_files(recipe):
             src_path = (recipe / rf).absolute()
             target_path = recipe_dir / rf
             target_path.parent.mkdir(parents=True, exist_ok=True)
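
For code calling the builder API directly, the 0.4.0 rename is mechanical. A hedged before/after sketch, assuming the signatures are unchanged apart from the names (the URL, repo, and key are placeholders):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    # 0.2.4: from hafnia.platform.builder import build_image, prepare_recipe
    from hafnia.platform.builder import build_image, prepare_trainer_package

    recipe_url = "https://example.com/trainer.zip"  # placeholder
    with TemporaryDirectory() as temp_dir:
        # 0.2.4: metadata = prepare_recipe(recipe_url, Path(temp_dir), api_key)
        metadata = prepare_trainer_package(recipe_url, Path(temp_dir), "hfn_xxx")
        build_image(metadata, "my-repo", state_file="state.json")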
cli/trainer_package_cmds.py ADDED
@@ -0,0 +1,65 @@
+from pathlib import Path
+from typing import Optional
+
+import click
+
+import cli.consts as consts
+from cli.config import Config
+
+
+@click.group(name="trainer")
+def trainer_package() -> None:
+    """Trainer package commands"""
+    pass
+
+
+@trainer_package.command(name="ls")
+@click.pass_obj
+@click.option("-l", "--limit", type=int, default=None, help="Limit number of listed trainer packages.")
+def cmd_list_trainer_packages(cfg: Config, limit: Optional[int]) -> None:
+    """List available trainer packages on the platform"""
+
+    from hafnia.platform.trainer_package import get_trainer_packages, pretty_print_trainer_packages
+
+    endpoint = cfg.get_platform_endpoint("trainers")
+    trainers = get_trainer_packages(endpoint, cfg.api_key)
+
+    pretty_print_trainer_packages(trainers, limit=limit)
+
+
+@trainer_package.command(name="create-zip")
+@click.argument("source")
+@click.option(
+    "--output",
+    type=click.Path(writable=True),
+    default="./trainer.zip",
+    show_default=True,
+    help="Output trainer package path.",
+)
+def cmd_create_trainer_package_zip(source: str, output: str) -> None:
+    """Create Hafnia trainer package as zip-file from local path"""
+
+    from hafnia.utils import archive_dir
+
+    path_output_zip = Path(output)
+    if path_output_zip.suffix != ".zip":
+        raise click.ClickException(consts.ERROR_TRAINER_PACKAGE_FILE_FORMAT)
+
+    path_source = Path(source)
+    path_output_zip = archive_dir(path_source, path_output_zip)
+
+
+@trainer_package.command(name="view-zip")
+@click.option("--path", type=str, default="./trainer.zip", show_default=True, help="Path of trainer.zip.")
+@click.option("--depth-limit", type=int, default=3, help="Limit the depth of the tree view.", show_default=True)
+def cmd_view_trainer_package_zip(path: str, depth_limit: int) -> None:
+    """View the content of a trainer package zip file."""
+    from hafnia.utils import show_trainer_package_content
+
+    path_trainer_package = Path(path)
+    if not path_trainer_package.exists():
+        raise click.ClickException(
+            f"Trainer package file '{path_trainer_package}' does not exist. Please provide a valid path. "
+            f"To create a trainer package, use the 'hafnia trainer create-zip' command."
+        )
+    show_trainer_package_content(path_trainer_package, depth_limit=depth_limit)
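
This new `trainer` group appears to replace the removed `cli/recipe_cmds.py` (file 46 above). The commands wrap helpers from `hafnia.utils`; a sketch of the same flow without the CLI, using only the signatures visible in the hunk above (paths are placeholders):

    from pathlib import Path

    from hafnia.utils import archive_dir, show_trainer_package_content

    # Equivalent to: hafnia trainer create-zip ./my_trainer --output ./trainer.zip
    path_zip = archive_dir(Path("./my_trainer"), Path("./trainer.zip"))

    # Equivalent to: hafnia trainer view-zip --path ./trainer.zip --depth-limit 3
    show_trainer_package_content(path_zip, depth_limit=3)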
hafnia/__init__.py CHANGED
@@ -2,3 +2,5 @@ from importlib.metadata import version
 
 __package_name__ = "hafnia"
 __version__ = version(__package_name__)
+
+__dataset_format_version__ = "0.1.0"  # Hafnia dataset format version
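
Both constants are importable at runtime, which gives a quick compatibility check:

    import hafnia

    print(hafnia.__version__)                 # package version, e.g. "0.4.0"
    print(hafnia.__dataset_format_version__)  # dataset format version: "0.1.0"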
hafnia/data/factory.py CHANGED
@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 from typing import Any
 
@@ -16,7 +15,7 @@ def load_dataset(recipe: Any, force_redownload: bool = False) -> HafniaDataset:
 
 def get_dataset_path(recipe: Any, force_redownload: bool = False) -> Path:
     if utils.is_hafnia_cloud_job():
-        return Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
+        return utils.get_dataset_path_in_hafnia_cloud()
 
     path_dataset = get_or_create_dataset_path_from_recipe(recipe, force_redownload=force_redownload)
 
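
The hardcoded `MDI_DATASET_DIR` lookup moves behind `utils.get_dataset_path_in_hafnia_cloud()`, so the cloud-path logic now lives in one place. Callers are unchanged; a usage sketch with a placeholder recipe name:

    from hafnia.data.factory import get_dataset_path, load_dataset

    # On a Hafnia cloud job this resolves to the mounted dataset directory;
    # locally it downloads/caches the dataset described by the recipe.
    path = get_dataset_path("mnist")  # "mnist" is a placeholder recipe/name
    dataset = load_dataset("mnist", force_redownload=False)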
hafnia/dataset/dataset_helpers.py CHANGED
@@ -38,12 +38,19 @@ def hash_from_bytes(data: bytes) -> str:
 
 def save_image_with_hash_name(image: np.ndarray, path_folder: Path) -> Path:
     pil_image = Image.fromarray(image)
+    path_image = save_pil_image_with_hash_name(pil_image, path_folder)
+    return path_image
+
+
+def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_skip: bool = True) -> Path:
     buffer = io.BytesIO()
-    pil_image.save(buffer, format="PNG")
+    image.save(buffer, format="PNG")
     hash_value = hash_from_bytes(buffer.getvalue())
     path_image = Path(path_folder) / relative_path_from_hash(hash=hash_value, suffix=".png")
+    if allow_skip and path_image.exists():
+        return path_image
     path_image.parent.mkdir(parents=True, exist_ok=True)
-    pil_image.save(path_image)
+    image.save(path_image)
     return path_image
 
 
@@ -110,15 +117,3 @@ def split_sizes_from_ratios(n_items: int, split_ratios: Dict[str, float]) -> Dic
         raise ValueError("Something is wrong. The split sizes do not match the number of items.")
 
     return split_sizes
-
-
-def select_evenly_across_list(lst: list, num_samples: int):
-    if num_samples >= len(lst):
-        return lst  # No need to sample
-    step = (len(lst) - 1) / (num_samples - 1)
-    indices = [int(round(step * i)) for i in range(num_samples)]  # noqa: RUF046
-    return [lst[index] for index in indices]
-
-
-def prefix_dict(d: dict, prefix: str) -> dict:
-    return {f"{prefix}.{k}": v for k, v in d.items()}
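
`save_image_with_hash_name` is now a thin wrapper over the new `save_pil_image_with_hash_name`, which names files by a hash of their PNG bytes and, with `allow_skip=True`, skips rewriting files that already exist. A sketch of the resulting deduplication, assuming identical pixels encode to identical PNG bytes:

    from pathlib import Path

    import numpy as np
    from PIL import Image

    from hafnia.dataset.dataset_helpers import (
        save_image_with_hash_name,
        save_pil_image_with_hash_name,
    )

    image = np.zeros((64, 64, 3), dtype=np.uint8)
    path_a = save_image_with_hash_name(image, Path("cache/images"))
    # Same content -> same hash -> same path; the second call is a no-op.
    path_b = save_pil_image_with_hash_name(Image.fromarray(image), Path("cache/images"))
    assert path_a == path_b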
hafnia/dataset/dataset_names.py CHANGED
@@ -18,11 +18,14 @@ class DeploymentStage(Enum):
     PRODUCTION = "production"
 
 
+TAG_IS_SAMPLE = "sample"
+
+OPS_REMOVE_CLASS = "__REMOVE__"
+
+
 class FieldName:
     CLASS_NAME: str = "class_name"  # Name of the class this primitive is associated with, e.g. "car" for Bbox
-    CLASS_IDX: str = (
-        "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class
-    )
+    CLASS_IDX: str = "class_idx"  # Index of the class this primitive is associated with, e.g. 0 for "car" if it is the first class  # noqa: E501
     OBJECT_ID: str = "object_id"  # Unique identifier for the object, e.g. "12345123"
     CONFIDENCE: str = "confidence"  # Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox
 
@@ -46,13 +49,15 @@ class FieldName:
 
 class ColumnName:
     SAMPLE_INDEX: str = "sample_index"
-    FILE_NAME: str = "file_name"
+    FILE_PATH: str = "file_path"
     HEIGHT: str = "height"
     WIDTH: str = "width"
     SPLIT: str = "split"
-    IS_SAMPLE: str = "is_sample"
     REMOTE_PATH: str = "remote_path"  # Path to the file in remote storage, e.g. S3
+    ATTRIBUTION: str = "attribution"  # Attribution for the sample (image/video), e.g. creator, license, source, etc.
+    TAGS: str = "tags"
     META: str = "meta"
+    DATASET_NAME: str = "dataset_name"
 
 
 class SplitName:
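
Downstream code keying on these constants needs two renames: `FILE_NAME` ("file_name") becomes `FILE_PATH` ("file_path"), and the boolean `IS_SAMPLE` column is gone, apparently superseded by a `"sample"` entry (see `TAG_IS_SAMPLE`) in the new `TAGS` column. A small illustration of the 0.4.0 constants:

    from hafnia.dataset.dataset_names import TAG_IS_SAMPLE, ColumnName

    print(ColumnName.FILE_PATH)  # "file_path" (was ColumnName.FILE_NAME)
    print(ColumnName.TAGS)       # "tags"
    print(TAG_IS_SAMPLE)         # "sample"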
hafnia/dataset/dataset_recipe/dataset_recipe.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import (
     field_serializer,
@@ -12,11 +12,13 @@ from pydantic import (
 
 from hafnia import utils
 from hafnia.dataset.dataset_recipe import recipe_transforms
-from hafnia.dataset.dataset_recipe.recipe_types import RecipeCreation, RecipeTransform, Serializable
+from hafnia.dataset.dataset_recipe.recipe_types import (
+    RecipeCreation,
+    RecipeTransform,
+    Serializable,
+)
 from hafnia.dataset.hafnia_dataset import HafniaDataset
-
-if TYPE_CHECKING:
-    from hafnia.dataset.hafnia_dataset import HafniaDataset
+from hafnia.dataset.primitives.primitive import Primitive
 
 
 class DatasetRecipe(Serializable):
@@ -43,6 +45,17 @@ class DatasetRecipe(Serializable):
         creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
         return DatasetRecipe(creation=creation)
 
+    @staticmethod
+    def from_name_public_dataset(
+        name: str, force_redownload: bool = False, n_samples: Optional[int] = None
+    ) -> DatasetRecipe:
+        creation = FromNamePublicDataset(
+            name=name,
+            force_redownload=force_redownload,
+            n_samples=n_samples,
+        )
+        return DatasetRecipe(creation=creation)
+
     @staticmethod
     def from_path(path_folder: Path, check_for_images: bool = True) -> DatasetRecipe:
         creation = FromPath(path_folder=path_folder, check_for_images=check_for_images)
@@ -76,6 +89,42 @@ class DatasetRecipe(Serializable):
         json_str = path_json.read_text(encoding="utf-8")
         return DatasetRecipe.from_json_str(json_str)
 
+    @staticmethod
+    def from_dict(data: Dict[str, Any]) -> "DatasetRecipe":
+        """Deserialize from a dictionary."""
+        dataset_recipe = Serializable.from_dict(data)
+        return dataset_recipe
+
+    @staticmethod
+    def from_recipe_id(recipe_id: str) -> "DatasetRecipe":
+        """Loads a dataset recipe by id from the hafnia platform."""
+        from cli.config import Config
+        from hafnia.platform.dataset_recipe import get_dataset_recipe_by_id
+
+        cfg = Config()
+        endpoint_dataset = cfg.get_platform_endpoint("dataset_recipes")
+        recipe_dict = get_dataset_recipe_by_id(recipe_id, endpoint=endpoint_dataset, api_key=cfg.api_key)
+        recipe_dict = recipe_dict["template"]["body"]
+        if isinstance(recipe_dict, str):
+            return DatasetRecipe.from_implicit_form(recipe_dict)
+
+        recipe = DatasetRecipe.from_dict(recipe_dict)
+        return recipe
+
+    @staticmethod
+    def from_recipe_name(name: str) -> "DatasetRecipe":
+        """Loads a dataset recipe by name from the hafnia platform"""
+        from cli.config import Config
+        from hafnia.platform.dataset_recipe import get_dataset_recipe_by_name
+
+        cfg = Config()
+        endpoint_dataset = cfg.get_platform_endpoint("dataset_recipes")
+        recipe = get_dataset_recipe_by_name(name=name, endpoint=endpoint_dataset, api_key=cfg.api_key)
+        if not recipe:
+            raise ValueError(f"Dataset recipe '{name}' not found.")
+        recipe_id = recipe["id"]
+        return DatasetRecipe.from_recipe_id(recipe_id)
+
     @staticmethod
     def from_implicit_form(recipe: Any) -> DatasetRecipe:
         """
@@ -152,6 +201,60 @@
 
         raise ValueError(f"Unsupported recipe type: {type(recipe)}")
 
+    ### Upload, store and recipe conversions ###
+    def as_python_code(self, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
+        str_operations = [self.creation.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)]
+        if self.operations:
+            for op in self.operations:
+                str_operations.append(op.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs))
+        operations_str = ".".join(str_operations)
+        return operations_str
+
+    def as_short_name(self) -> str:
+        """Return a short name for the transforms."""
+
+        creation_name = self.creation.as_short_name()
+        if self.operations is None or len(self.operations) == 0:
+            return creation_name
+        short_names = [creation_name]
+        for operation in self.operations:
+            short_names.append(operation.as_short_name())
+        transforms_str = ",".join(short_names)
+        return f"Recipe({transforms_str})"
+
+    def as_json_str(self, indent: int = 2) -> str:
+        """Serialize the dataset recipe to a JSON string."""
+        dict_data = self.as_dict()
+        return json.dumps(dict_data, indent=indent, ensure_ascii=False)
+
+    def as_json_file(self, path_json: Path, indent: int = 2) -> None:
+        """Serialize the dataset recipe to a JSON file."""
+        path_json.parent.mkdir(parents=True, exist_ok=True)
+        json_str = self.as_json_str(indent=indent)
+        path_json.write_text(json_str, encoding="utf-8")
+
+    def as_dict(self) -> dict:
+        """Serialize the dataset recipe to a dictionary."""
+        return self.model_dump(mode="json")
+
+    def as_platform_recipe(self, recipe_name: Optional[str], overwrite: bool = False) -> Dict:
+        """Uploads dataset recipe to the hafnia platform."""
+        from cli.config import Config
+        from hafnia.platform.dataset_recipe import get_or_create_dataset_recipe
+
+        recipe = self.as_dict()
+        cfg = Config()
+        endpoint_dataset = cfg.get_platform_endpoint("dataset_recipes")
+        recipe_dict = get_or_create_dataset_recipe(
+            recipe=recipe,
+            endpoint=endpoint_dataset,
+            api_key=cfg.api_key,
+            name=recipe_name,
+            overwrite=overwrite,
+        )
+
+        return recipe_dict
+
     ### Dataset Recipe Transformations ###
     def shuffle(recipe: DatasetRecipe, seed: int = 42) -> DatasetRecipe:
         operation = recipe_transforms.Shuffle(seed=seed)
@@ -159,10 +262,17 @@
         return recipe
 
     def select_samples(
-        recipe: DatasetRecipe, n_samples: int, shuffle: bool = True, seed: int = 42, with_replacement: bool = False
+        recipe: DatasetRecipe,
+        n_samples: int,
+        shuffle: bool = True,
+        seed: int = 42,
+        with_replacement: bool = False,
     ) -> DatasetRecipe:
         operation = recipe_transforms.SelectSamples(
-            n_samples=n_samples, shuffle=shuffle, seed=seed, with_replacement=with_replacement
+            n_samples=n_samples,
+            shuffle=shuffle,
+            seed=seed,
+            with_replacement=with_replacement,
         )
         recipe.append_operation(operation)
@@ -184,37 +294,36 @@
         recipe.append_operation(operation)
         return recipe
 
-    ### Conversions ###
-    def as_python_code(self, keep_default_fields: bool = False, as_kwargs: bool = True) -> str:
-        str_operations = [self.creation.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs)]
-        if self.operations:
-            for op in self.operations:
-                str_operations.append(op.as_python_code(keep_default_fields=keep_default_fields, as_kwargs=as_kwargs))
-        operations_str = ".".join(str_operations)
-        return operations_str
-
-    def as_short_name(self) -> str:
-        """Return a short name for the transforms."""
-
-        creation_name = self.creation.as_short_name()
-        if self.operations is None or len(self.operations) == 0:
-            return creation_name
-        short_names = [creation_name]
-        for operation in self.operations:
-            short_names.append(operation.as_short_name())
-        transforms_str = ",".join(short_names)
-        return f"Recipe({transforms_str})"
+    def class_mapper(
+        recipe: DatasetRecipe,
+        class_mapping: Union[Dict[str, str], List[Tuple[str, str]]],
+        method: str = "strict",
+        primitive: Optional[Type[Primitive]] = None,
+        task_name: Optional[str] = None,
+    ) -> DatasetRecipe:
+        operation = recipe_transforms.ClassMapper(
+            class_mapping=class_mapping,
+            method=method,
+            primitive=primitive,
+            task_name=task_name,
+        )
+        recipe.append_operation(operation)
+        return recipe
 
-    def as_json_str(self, indent: int = 2) -> str:
-        """Serialize the dataset recipe to a JSON string."""
-        data = self.model_dump(mode="json")
-        # data = type_as_first_key(data)
-        return json.dumps(data, indent=indent, ensure_ascii=False)
+    def rename_task(recipe: DatasetRecipe, old_task_name: str, new_task_name: str) -> DatasetRecipe:
+        operation = recipe_transforms.RenameTask(old_task_name=old_task_name, new_task_name=new_task_name)
+        recipe.append_operation(operation)
+        return recipe
 
-    def as_json_file(self, path_json: Path, indent: int = 2) -> None:
-        """Serialize the dataset recipe to a JSON file."""
-        json_str = self.as_json_str(indent=indent)
-        path_json.write_text(json_str, encoding="utf-8")
+    def select_samples_by_class_name(
+        recipe: DatasetRecipe,
+        name: Union[List[str], str],
+        task_name: Optional[str] = None,
+        primitive: Optional[Type[Primitive]] = None,
+    ) -> DatasetRecipe:
+        operation = recipe_transforms.SelectSamplesByClassName(name=name, task_name=task_name, primitive=primitive)
+        recipe.append_operation(operation)
+        return recipe
 
     ### Helper methods ###
     def get_dataset_names(self) -> List[str]:
@@ -314,6 +423,22 @@ class FromName(RecipeCreation):
         return [self.name]
 
 
+class FromNamePublicDataset(RecipeCreation):
+    name: str
+    force_redownload: bool = False
+    n_samples: Optional[int] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.from_name_public_dataset
+
+    def as_short_name(self) -> str:
+        return f"Torchvision('{self.name}')"
+
+    def get_dataset_names(self) -> List[str]:
+        return []
+
+
 class FromMerge(RecipeCreation):
     recipe0: DatasetRecipe
     recipe1: DatasetRecipe
@@ -328,7 +453,10 @@
 
     def get_dataset_names(self) -> List[str]:
         """Get the dataset names from the merged recipes."""
-        names = [*self.recipe0.creation.get_dataset_names(), *self.recipe1.creation.get_dataset_names()]
+        names = [
+            *self.recipe0.creation.get_dataset_names(),
+            *self.recipe1.creation.get_dataset_names(),
+        ]
         return names
 
 
@@ -353,33 +481,3 @@ class FromMerger(RecipeCreation):
         for recipe in self.recipes:
             names.extend(recipe.creation.get_dataset_names())
         return names
-
-
-def extract_dataset_names_from_json_dict(data: dict) -> list[str]:
-    """
-    Extract dataset names recursively from a JSON dictionary added with 'from_name'.
-
-    Even if the same functionality is achieved with `DatasetRecipe.get_dataset_names()`,
-    we want to keep this function in 'dipdatalib' to extract dataset names from json dictionaries
-    directly.
-    """
-    creation_field = data.get("creation")
-    if creation_field is None:
-        return []
-    if creation_field.get("__type__") == "FromName":
-        return [creation_field["name"]]
-    elif creation_field.get("__type__") == "FromMerge":
-        recipe_names = ["recipe0", "recipe1"]
-        dataset_name = []
-        for recipe_name in recipe_names:
-            recipe = creation_field.get(recipe_name)
-            if recipe is None:
-                continue
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    elif creation_field.get("__type__") == "FromMerger":
-        dataset_name = []
-        for recipe in creation_field.get("recipes", []):
-            dataset_name.extend(extract_dataset_names_from_json_dict(recipe))
-        return dataset_name
-    return []
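
Taken together, the new transform methods keep the fluent recipe style; a sketch chaining them, with a placeholder dataset name and class mapping:

    from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

    recipe = (
        DatasetRecipe.from_name("traffic-dataset")  # placeholder name
        .class_mapper(class_mapping={"car": "vehicle", "truck": "vehicle"})
        .rename_task(old_task_name="objects", new_task_name="vehicles")
        .select_samples(n_samples=100, shuffle=True, seed=42)
    )

    # Round-trips through the new dict serialization helpers:
    recipe_copy = DatasetRecipe.from_dict(recipe.as_dict())
    print(recipe.as_short_name())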
hafnia/dataset/dataset_recipe/recipe_transforms.py CHANGED
@@ -1,10 +1,10 @@
-from typing import TYPE_CHECKING, Callable, Dict
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import field_validator
 
 from hafnia.dataset.dataset_recipe.recipe_types import RecipeTransform
 from hafnia.dataset.hafnia_dataset import HafniaDataset
-
-if TYPE_CHECKING:
-    pass
+from hafnia.dataset.primitives.primitive import Primitive
 
 
 class Shuffle(RecipeTransform):
@@ -51,3 +51,47 @@ class DefineSampleSetBySize(RecipeTransform):
     @staticmethod
     def get_function() -> Callable[..., "HafniaDataset"]:
         return HafniaDataset.define_sample_set_by_size
+
+
+class ClassMapper(RecipeTransform):
+    class_mapping: Union[Dict[str, str], List[Tuple[str, str]]]
+    method: str = "strict"
+    primitive: Optional[Type[Primitive]] = None
+    task_name: Optional[str] = None
+
+    @field_validator("class_mapping", mode="after")
+    @classmethod
+    def serialize_class_mapping(cls, value: Union[Dict[str, str], List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
+        # Converts the dictionary class mapping to a list of tuples
+        # e.g. {"old_class": "new_class", } --> [("old_class", "new_class")]
+        # The reason is that storing class mappings as a dictionary does not preserve order of json fields
+        # when stored in a database as a jsonb field (postgres).
+        # Preserving order of class mapping fields is important as it defines the indices of the classes.
+        # So to ensure that class indices are maintained, we preserve order of json fields, by converting the
+        # dictionary to a list of tuples.
+        if isinstance(value, dict):
+            value = list(value.items())
+        return value
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.class_mapper
+
+
+class RenameTask(RecipeTransform):
+    old_task_name: str
+    new_task_name: str
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.rename_task
+
+
+class SelectSamplesByClassName(RecipeTransform):
+    name: Union[List[str], str]
+    task_name: Optional[str] = None
+    primitive: Optional[Type[Primitive]] = None
+
+    @staticmethod
+    def get_function() -> Callable[..., "HafniaDataset"]:
+        return HafniaDataset.select_samples_by_class_name
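
The `ClassMapper` validator means a recipe defined with a dict and one defined with a list of tuples serialize identically, and order (and therefore class indices) survives a jsonb round-trip. A quick check of the normalization:

    from hafnia.dataset.dataset_recipe.recipe_transforms import ClassMapper

    mapper = ClassMapper(class_mapping={"car": "vehicle", "truck": "vehicle"})
    assert mapper.class_mapping == [("car", "vehicle"), ("truck", "vehicle")]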
hafnia/dataset/dataset_recipe/recipe_types.py CHANGED
@@ -8,7 +8,7 @@ from pydantic import BaseModel, computed_field
 
 from hafnia import utils
 
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
     from hafnia.dataset.hafnia_dataset import HafniaDataset
 