hafnia 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cli/__main__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python
2
2
  import click
3
3
 
4
- from cli import consts, data_cmds, experiment_cmds, profile_cmds, runc_cmds
4
+ from cli import consts, data_cmds, experiment_cmds, profile_cmds, recipe_cmds, runc_cmds
5
5
  from cli.config import Config, ConfigSchema
6
6
 
7
7
 
@@ -10,6 +10,7 @@ from cli.config import Config, ConfigSchema
10
10
  def main(ctx: click.Context) -> None:
11
11
  """Hafnia CLI."""
12
12
  ctx.obj = Config()
13
+ ctx.max_content_width = 120
13
14
 
14
15
 
15
16
  @main.command("configure")
@@ -17,9 +18,7 @@ def main(ctx: click.Context) -> None:
17
18
  def configure(cfg: Config) -> None:
18
19
  """Configure Hafnia CLI settings."""
19
20
 
20
- from hafnia.platform.api import get_organization_id
21
-
22
- profile_name = click.prompt("Profile Name", type=str, default="default")
21
+ profile_name = click.prompt("Profile Name", type=str, default=consts.DEFAULT_PROFILE_NAME)
23
22
  profile_name = profile_name.strip()
24
23
  try:
25
24
  cfg.add_profile(profile_name, ConfigSchema(), set_active=True)
@@ -32,12 +31,8 @@ def configure(cfg: Config) -> None:
32
31
  except ValueError as e:
33
32
  click.echo(f"Error: {str(e)}", err=True)
34
33
  return
35
- platform_url = click.prompt("Hafnia Platform URL", type=str, default="https://api.mdi.milestonesys.com")
34
+ platform_url = click.prompt("Hafnia Platform URL", type=str, default=consts.DEFAULT_API_URL)
36
35
  cfg.platform_url = platform_url.strip()
37
- try:
38
- cfg.organization_id = get_organization_id(cfg.get_platform_endpoint("organizations"), cfg.api_key)
39
- except Exception:
40
- raise click.ClickException(consts.ERROR_ORG_ID)
41
36
  cfg.save_config()
42
37
  profile_cmds.profile_show(cfg)
43
38
 
@@ -54,6 +49,7 @@ main.add_command(profile_cmds.profile)
54
49
  main.add_command(data_cmds.data)
55
50
  main.add_command(runc_cmds.runc)
56
51
  main.add_command(experiment_cmds.experiment)
52
+ main.add_command(recipe_cmds.recipe)
57
53
 
58
54
  if __name__ == "__main__":
59
- main()
55
+ main(max_content_width=120)
cli/config.py CHANGED
@@ -6,14 +6,21 @@ from typing import Dict, List, Optional
6
6
  from pydantic import BaseModel, field_validator
7
7
 
8
8
  import cli.consts as consts
9
- from hafnia.log import logger
9
+ from hafnia.log import user_logger
10
+
11
+ PLATFORM_API_MAPPING = {
12
+ "recipes": "/api/v1/recipes",
13
+ "experiments": "/api/v1/experiments",
14
+ "experiment_environments": "/api/v1/experiment-environments",
15
+ "experiment_runs": "/api/v1/experiment-runs",
16
+ "runs": "/api/v1/experiments-runs",
17
+ "datasets": "/api/v1/datasets",
18
+ }
10
19
 
11
20
 
12
21
  class ConfigSchema(BaseModel):
13
- organization_id: str = ""
14
22
  platform_url: str = ""
15
23
  api_key: Optional[str] = None
16
- api_mapping: Optional[Dict[str, str]] = None
17
24
 
18
25
  @field_validator("api_key")
19
26
  def validate_api_key(cls, value: str) -> str:
@@ -61,14 +68,6 @@ class Config:
61
68
  def api_key(self, value: str) -> None:
62
69
  self.config.api_key = value
63
70
 
64
- @property
65
- def organization_id(self) -> str:
66
- return self.config.organization_id
67
-
68
- @organization_id.setter
69
- def organization_id(self, value: str) -> None:
70
- self.config.organization_id = value
71
-
72
71
  @property
73
72
  def platform_url(self) -> str:
74
73
  return self.config.platform_url
@@ -77,7 +76,6 @@ class Config:
77
76
  def platform_url(self, value: str) -> None:
78
77
  base_url = value.rstrip("/")
79
78
  self.config.platform_url = base_url
80
- self.config.api_mapping = self.get_api_mapping(base_url)
81
79
 
82
80
  def __init__(self, config_path: Optional[Path] = None) -> None:
83
81
  self.config_path = self.resolve_config_path(config_path)
@@ -96,27 +94,22 @@ class Config:
96
94
 
97
95
  def add_profile(self, profile_name: str, profile: ConfigSchema, set_active: bool = False) -> None:
98
96
  profile_name = profile_name.strip()
97
+ if profile_name in self.config_data.profiles:
98
+ user_logger.warning(
99
+ f"Profile with name '{profile_name}' already exists, it will be overwritten by the new one."
100
+ )
101
+
99
102
  self.config_data.profiles[profile_name] = profile
100
103
  if set_active:
101
104
  self.config_data.active_profile = profile_name
102
105
  self.save_config()
103
106
 
104
- def get_api_mapping(self, base_url: str) -> Dict:
105
- return {
106
- "organizations": f"{base_url}/api/v1/organizations",
107
- "recipes": f"{base_url}/api/v1/recipes",
108
- "experiments": f"{base_url}/api/v1/experiments",
109
- "experiment_environments": f"{base_url}/api/v1/experiment-environments",
110
- "experiment_runs": f"{base_url}/api/v1/experiment-runs",
111
- "runs": f"{base_url}/api/v1/experiments-runs",
112
- "datasets": f"{base_url}/api/v1/datasets",
113
- }
114
-
115
107
  def get_platform_endpoint(self, method: str) -> str:
116
108
  """Get specific API endpoint"""
117
- if not self.config.api_mapping or method not in self.config.api_mapping:
118
- raise ValueError(f"{method} is not supported.")
119
- return self.config.api_mapping[method]
109
+ if method not in PLATFORM_API_MAPPING:
110
+ raise ValueError(f"'{method}' is not supported.")
111
+ endpoint = self.config.platform_url + PLATFORM_API_MAPPING[method]
112
+ return endpoint
120
113
 
121
114
  def load_config(self) -> ConfigFileSchema:
122
115
  """Load configuration from file."""
@@ -127,7 +120,7 @@ class Config:
127
120
  data = json.load(f)
128
121
  return ConfigFileSchema(**data)
129
122
  except json.JSONDecodeError:
130
- logger.error("Error decoding JSON file.")
123
+ user_logger.error("Error decoding JSON file.")
131
124
  raise ValueError("Failed to parse configuration file")
132
125
 
133
126
  def save_config(self) -> None:
cli/consts.py CHANGED
@@ -1,13 +1,16 @@
1
+ DEFAULT_API_URL = "https://api.mdi.milestonesys.com"
2
+ DEFAULT_PROFILE_NAME = "default"
3
+
1
4
  ERROR_CONFIGURE: str = "Please configure the CLI with `hafnia configure`"
2
5
  ERROR_PROFILE_NOT_EXIST: str = "No active profile configured. Please configure the CLI with `hafnia configure`"
3
6
  ERROR_PROFILE_REMOVE_ACTIVE: str = "Cannot remove active profile. Please switch to another profile first."
4
7
  ERROR_API_KEY_NOT_SET: str = "API key not set. Please configure the CLI with `hafnia configure`."
5
- ERROR_ORG_ID: str = "Failed to fetch organization ID. Verify platform URL and API key."
6
8
  ERROR_CREATE_PROFILE: str = "Failed to create profile. Profile name must be unique and not empty."
7
9
 
8
10
  ERROR_GET_RESOURCE: str = "Failed to get the data from platform. Verify url or api key."
9
11
 
10
12
  ERROR_EXPERIMENT_DIR: str = "Source directory does not exist"
13
+ ERROR_RECIPE_FILE_FORMAT: str = "Recipe filename must be a '.zip' file"
11
14
 
12
15
  PROFILE_SWITCHED_SUCCESS: str = "Switched to profile:"
13
16
  PROFILE_REMOVED_SUCCESS: str = "Removed profile:"
cli/data_cmds.py CHANGED
@@ -1,3 +1,4 @@
1
+ from pathlib import Path
1
2
  from typing import Optional
2
3
 
3
4
  import click
@@ -35,20 +36,18 @@ def data_get(cfg: Config, url: str, destination: click.Path) -> None:
35
36
  @click.argument("destination", default=None, required=False)
36
37
  @click.option("--force", is_flag=True, default=False, help="Force download")
37
38
  @click.pass_obj
38
- def data_download(cfg: Config, dataset_name: str, destination: Optional[click.Path], force: bool) -> None:
39
+ def data_download(cfg: Config, dataset_name: str, destination: Optional[click.Path], force: bool) -> Path:
39
40
  """Download dataset from Hafnia platform"""
40
41
 
41
42
  from hafnia.data.factory import download_or_get_dataset_path
42
43
 
43
44
  try:
44
- endpoint_dataset = cfg.get_platform_endpoint("datasets")
45
- api_key = cfg.api_key
46
- download_or_get_dataset_path(
45
+ path_dataset = download_or_get_dataset_path(
47
46
  dataset_name=dataset_name,
48
- endpoint=endpoint_dataset,
49
- api_key=api_key,
47
+ cfg=cfg,
50
48
  output_dir=destination,
51
49
  force_redownload=force,
52
50
  )
53
51
  except Exception:
54
52
  raise click.ClickException(consts.ERROR_GET_RESOURCE)
53
+ return path_dataset
cli/experiment_cmds.py CHANGED
@@ -13,32 +13,6 @@ def experiment() -> None:
13
13
  pass
14
14
 
15
15
 
16
- @experiment.command(name="create_recipe")
17
- @click.option("--source_folder", default=".", type=Path, help="Path to the source folder", show_default=True)
18
- @click.option(
19
- "--recipe_filename",
20
- default="recipe.zip",
21
- type=Path,
22
- help="Recipe filename. Should have a '.zip' suffix",
23
- show_default=True,
24
- )
25
- def create_recipe(source_folder: str, recipe_filename: str) -> None:
26
- """Build recipe from local path as image with prefix - localhost"""
27
-
28
- from hafnia.platform.builder import validate_recipe
29
- from hafnia.utils import archive_dir
30
-
31
- path_output_zip = Path(recipe_filename)
32
-
33
- if path_output_zip.suffix != ".zip":
34
- raise click.ClickException("Recipe filename must be a '.zip' file")
35
-
36
- path_source = Path(source_folder)
37
-
38
- path_output_zip = archive_dir(path_source, path_output_zip)
39
- validate_recipe(path_output_zip)
40
-
41
-
42
16
  @experiment.command(name="create")
43
17
  @click.argument("name")
44
18
  @click.argument("source_dir", type=Path)
@@ -48,25 +22,18 @@ def create_recipe(source_folder: str, recipe_filename: str) -> None:
48
22
  @click.pass_obj
49
23
  def create(cfg: Config, name: str, source_dir: Path, exec_cmd: str, dataset_name: str, env_name: str) -> None:
50
24
  """Create a new experiment run"""
51
- from hafnia.platform import (
52
- create_experiment,
53
- create_recipe,
54
- get_dataset_id,
55
- get_exp_environment_id,
56
- )
25
+ from hafnia.platform import create_experiment, create_recipe, get_dataset_id, get_exp_environment_id
57
26
 
58
27
  if not source_dir.exists():
59
28
  raise click.ClickException(consts.ERROR_EXPERIMENT_DIR)
60
29
 
61
30
  try:
62
31
  dataset_id = get_dataset_id(dataset_name, cfg.get_platform_endpoint("datasets"), cfg.api_key)
63
- except (IndexError, KeyError):
64
- raise click.ClickException(f"Dataset '{dataset_name}' not found.")
65
32
  except Exception:
66
33
  raise click.ClickException(f"Error retrieving dataset '{dataset_name}'.")
67
34
 
68
35
  try:
69
- recipe_id = create_recipe(source_dir, cfg.get_platform_endpoint("recipes"), cfg.api_key, cfg.organization_id)
36
+ recipe_id = create_recipe(source_dir, cfg.get_platform_endpoint("recipes"), cfg.api_key)
70
37
  except Exception:
71
38
  raise click.ClickException(f"Failed to create recipe from '{source_dir}'")
72
39
 
@@ -77,14 +44,7 @@ def create(cfg: Config, name: str, source_dir: Path, exec_cmd: str, dataset_name
77
44
 
78
45
  try:
79
46
  experiment_id = create_experiment(
80
- name,
81
- dataset_id,
82
- recipe_id,
83
- exec_cmd,
84
- env_id,
85
- cfg.get_platform_endpoint("experiments"),
86
- cfg.api_key,
87
- cfg.organization_id,
47
+ name, dataset_id, recipe_id, exec_cmd, env_id, cfg.get_platform_endpoint("experiments"), cfg.api_key
88
48
  )
89
49
  except Exception:
90
50
  raise click.ClickException(f"Failed to create experiment '{name}'")
cli/profile_cmds.py CHANGED
@@ -3,7 +3,7 @@ from rich.console import Console
3
3
  from rich.table import Table
4
4
 
5
5
  import cli.consts as consts
6
- from cli.config import Config
6
+ from cli.config import Config, ConfigSchema
7
7
 
8
8
 
9
9
  @click.group()
@@ -43,6 +43,21 @@ def profile_use(cfg: Config, profile_name: str) -> None:
43
43
  click.echo(f"{consts.PROFILE_SWITCHED_SUCCESS} {profile_name}")
44
44
 
45
45
 
46
+ @profile.command("create")
47
+ @click.argument("api-key", required=True)
48
+ @click.option("--name", help="Specify profile name", default=consts.DEFAULT_PROFILE_NAME, show_default=True)
49
+ @click.option("--api-url", help="API URL", default=consts.DEFAULT_API_URL, show_default=True)
50
+ @click.option(
51
+ "--activate/--no-activate", help="Activate the created profile after creation", default=True, show_default=True
52
+ )
53
+ @click.pass_obj
54
+ def profile_create(cfg: Config, name: str, api_url: str, api_key: str, activate: bool) -> None:
55
+ """Create a new profile."""
56
+ cfg_profile = ConfigSchema(platform_url=api_url, api_key=api_key)
57
+
58
+ cfg.add_profile(profile_name=name, profile=cfg_profile, set_active=activate)
59
+
60
+
46
61
  @profile.command("rm")
47
62
  @click.argument("profile_name", required=True)
48
63
  @click.pass_obj
@@ -80,7 +95,6 @@ def profile_show(cfg: Config) -> None:
80
95
  table.add_column("Value")
81
96
 
82
97
  table.add_row("API Key", masked_key)
83
- table.add_row("Organization", cfg.organization_id)
84
98
  table.add_row("Platform URL", cfg.platform_url)
85
99
  table.add_row("Config File", cfg.config_path.as_posix())
86
100
  console.print(table)
cli/recipe_cmds.py ADDED
@@ -0,0 +1,45 @@
1
+ from pathlib import Path
2
+
3
+ import click
4
+
5
+ import cli.consts as consts
6
+
7
+
8
+ @click.group(name="recipe")
9
+ def recipe() -> None:
10
+ """Hafnia Recipe management commands"""
11
+ pass
12
+
13
+
14
+ @recipe.command(name="create")
15
+ @click.argument("source")
16
+ @click.option(
17
+ "--output", type=click.Path(writable=True), default="./recipe.zip", show_default=True, help="Output recipe path."
18
+ )
19
+ def create(source: str, output: str) -> None:
20
+ """Create HRF from local path"""
21
+
22
+ from hafnia.utils import archive_dir
23
+
24
+ path_output_zip = Path(output)
25
+ if path_output_zip.suffix != ".zip":
26
+ raise click.ClickException(consts.ERROR_RECIPE_FILE_FORMAT)
27
+
28
+ path_source = Path(source)
29
+ path_output_zip = archive_dir(path_source, path_output_zip)
30
+
31
+
32
+ @recipe.command(name="view")
33
+ @click.option("--path", type=str, default="./recipe.zip", show_default=True, help="Path of recipe.zip.")
34
+ @click.option("--depth-limit", type=int, default=3, help="Limit the depth of the tree view.", show_default=True)
35
+ def view(path: str, depth_limit: int) -> None:
36
+ """View the content of a recipe zip file."""
37
+ from hafnia.utils import show_recipe_content
38
+
39
+ path_recipe = Path(path)
40
+ if not path_recipe.exists():
41
+ raise click.ClickException(
42
+ f"Recipe file '{path_recipe}' does not exist. Please provide a valid path. "
43
+ f"To create a recipe, use the 'hafnia recipe create' command."
44
+ )
45
+ show_recipe_content(path_recipe, depth_limit=depth_limit)
cli/runc_cmds.py CHANGED
@@ -1,10 +1,14 @@
1
- from hashlib import sha256
1
+ import json
2
+ import subprocess
3
+ import zipfile
2
4
  from pathlib import Path
3
5
  from tempfile import TemporaryDirectory
6
+ from typing import Optional
4
7
 
5
8
  import click
6
9
 
7
10
  from cli.config import Config
11
+ from hafnia.log import sys_logger, user_logger
8
12
 
9
13
 
10
14
  @click.group(name="runc")
@@ -13,56 +17,128 @@ def runc():
13
17
  pass
14
18
 
15
19
 
16
- @runc.command(name="launch")
17
- @click.argument("task", required=True)
18
- def launch(task: str) -> None:
20
+ @runc.command(name="launch-local")
21
+ @click.argument("exec_cmd", type=str)
22
+ @click.option(
23
+ "--dataset",
24
+ type=str,
25
+ help="Hafnia dataset name e.g. mnist, midwest-vehicle-detection or a path to a local dataset",
26
+ required=True,
27
+ )
28
+ @click.option(
29
+ "--image_name",
30
+ type=Optional[str],
31
+ default=None,
32
+ help=(
33
+ "Docker image name to use for the launch. "
34
+ "By default, it will use image name from '.state.json' "
35
+ "file generated by the 'hafnia runc build-local' command"
36
+ ),
37
+ )
38
+ @click.pass_obj
39
+ def launch_local(cfg: Config, exec_cmd: str, dataset: str, image_name: str) -> None:
19
40
  """Launch a job within the image."""
20
- from hafnia.platform.executor import handle_launch
41
+ from hafnia.data.factory import download_or_get_dataset_path
42
+
43
+ is_local_dataset = "/" in dataset
44
+ if is_local_dataset:
45
+ click.echo(f"Using local dataset: {dataset}")
46
+ path_dataset = Path(dataset)
47
+ if not path_dataset.exists():
48
+ raise click.ClickException(f"Dataset path does not exist: {path_dataset}")
49
+ else:
50
+ click.echo(f"Using Hafnia dataset: {dataset}")
51
+ path_dataset = download_or_get_dataset_path(dataset_name=dataset, cfg=cfg, force_redownload=False)
52
+
53
+ if image_name is None:
54
+ # Load image name from state.json
55
+ path_state_file = Path("state.json")
56
+ if not path_state_file.exists():
57
+ raise click.ClickException("State file does not exist. Please build the image first.")
58
+ state_dict = json.loads(path_state_file.read_text())
59
+ if "image_tag" not in state_dict:
60
+ raise click.ClickException("'image_tag' not found in state file. Please build the image first.")
61
+ image_name = state_dict["image_tag"]
21
62
 
22
- handle_launch(task)
63
+ docker_cmds = [
64
+ "docker",
65
+ "run",
66
+ "--rm",
67
+ "-v",
68
+ f"{path_dataset.absolute()}:/opt/ml/input/data/training",
69
+ "-e",
70
+ "HAFNIA_CLOUD=true",
71
+ "-e",
72
+ "PYTHONPATH=src",
73
+ "--runtime",
74
+ "nvidia",
75
+ image_name,
76
+ ] + exec_cmd.split(" ")
77
+
78
+ # Use the "hafnia runc launch" cmd when we have moved to the new folder structure and
79
+ # direct commands.
80
+ # Replace '+ exec_cmd.split(" ")' with '["hafnia", "runc", "launch"] + exec_cmd.split(" ")'
81
+
82
+ click.echo(f"Running command: \n\t{' '.join(docker_cmds)}")
83
+ subprocess.run(docker_cmds, check=True)
23
84
 
24
85
 
25
86
  @runc.command(name="build")
26
87
  @click.argument("recipe_url")
27
- @click.argument("state_file", default="state.json")
28
- @click.argument("ecr_repository", default="localhost")
29
- @click.argument("image_name", default="recipe")
88
+ @click.option("--state_file", "--st", type=str, default="state.json")
89
+ @click.option("--repo", type=str, default="localhost", help="Docker repository")
30
90
  @click.pass_obj
31
- def build(cfg: Config, recipe_url: str, state_file: str, ecr_repository: str, image_name: str) -> None:
91
+ def build(cfg: Config, recipe_url: str, state_file: str, repo: str) -> None:
32
92
  """Build docker image with a given recipe."""
33
93
  from hafnia.platform.builder import build_image, prepare_recipe
34
94
 
35
95
  with TemporaryDirectory() as temp_dir:
36
- image_info = prepare_recipe(recipe_url, Path(temp_dir), cfg.api_key)
37
- image_info["name"] = image_name
38
- build_image(image_info, ecr_repository, state_file)
96
+ metadata = prepare_recipe(recipe_url, Path(temp_dir), cfg.api_key)
97
+ build_image(metadata, repo, state_file=state_file)
39
98
 
40
99
 
41
100
  @runc.command(name="build-local")
42
101
  @click.argument("recipe")
43
- @click.argument("state_file", default="state.json")
44
- @click.argument("image_name", default="recipe")
45
- def build_local(recipe: str, state_file: str, image_name: str) -> None:
102
+ @click.option("--state_file", "--st", type=str, default="state.json")
103
+ @click.option("--repo", type=str, default="localhost", help="Docker repository")
104
+ def build_local(recipe: Path, state_file: str, repo: str) -> None:
46
105
  """Build recipe from local path as image with prefix - localhost"""
106
+ import shutil
107
+ import uuid
108
+
109
+ import seedir
110
+
111
+ from hafnia.platform.builder import build_image
112
+ from hafnia.utils import filter_recipe_files
113
+
114
+ recipe = Path(recipe)
115
+
116
+ with TemporaryDirectory() as d:
117
+ tmp_dir = Path(d)
118
+ recipe_dir = tmp_dir / "recipe"
119
+ recipe_dir.mkdir(parents=True, exist_ok=True)
120
+
121
+ if recipe.suffix == ".zip":
122
+ user_logger.info("Extracting recipe for processing.")
123
+ with zipfile.ZipFile(recipe.as_posix(), "r") as zip_ref:
124
+ zip_ref.extractall(recipe_dir)
125
+ elif recipe.is_dir():
126
+ for rf in filter_recipe_files(recipe):
127
+ src_path = (recipe / rf).absolute()
128
+ target_path = recipe_dir / rf
129
+ target_path.parent.mkdir(parents=True, exist_ok=True)
130
+ shutil.copyfile(src_path, target_path)
131
+
132
+ user_logger.info(
133
+ seedir.seedir(recipe_dir, sort=True, first="folders", style="emoji", printout=False, depthlimit=2)
134
+ )
135
+
136
+ metadata = {
137
+ "dockerfile": (recipe_dir / "Dockerfile").as_posix(),
138
+ "docker_context": recipe_dir.as_posix(),
139
+ "digest": uuid.uuid4().hex[:8],
140
+ }
47
141
 
48
- from hafnia.platform.builder import build_image, validate_recipe
49
- from hafnia.utils import archive_dir
50
-
51
- recipe_zip = Path(recipe)
52
- recipe_created = False
53
- if not recipe_zip.suffix == ".zip" and recipe_zip.is_dir():
54
- recipe_zip = archive_dir(recipe_zip)
55
- recipe_created = True
56
-
57
- validate_recipe(recipe_zip)
58
- click.echo("Recipe successfully validated")
59
- image_info = {
60
- "name": image_name,
61
- "dockerfile": f"{recipe_zip.parent}/Dockerfile",
62
- "docker_context": f"{recipe_zip.parent}",
63
- "hash": sha256(recipe_zip.read_bytes()).hexdigest()[:8],
64
- }
65
- click.echo("Start building image")
66
- build_image(image_info, "localhost", state_file=state_file)
67
- if recipe_created:
68
- recipe_zip.unlink()
142
+ user_logger.info("Start building image.")
143
+ sys_logger.debug(metadata)
144
+ build_image(metadata, repo, state_file=state_file)
hafnia/data/factory.py CHANGED
@@ -7,7 +7,7 @@ from datasets import Dataset, DatasetDict, load_from_disk
7
7
 
8
8
  from cli.config import Config
9
9
  from hafnia import utils
10
- from hafnia.log import logger
10
+ from hafnia.log import user_logger
11
11
  from hafnia.platform import download_resource, get_dataset_id
12
12
 
13
13
 
@@ -15,29 +15,33 @@ def load_local(dataset_path: Path) -> Union[Dataset, DatasetDict]:
15
15
  """Load a Hugging Face dataset from a local directory path."""
16
16
  if not dataset_path.exists():
17
17
  raise ValueError(f"Can not load dataset, directory does not exist -- {dataset_path}")
18
- logger.info(f"Loading data from {dataset_path.as_posix()}")
18
+ user_logger.info(f"Loading data from {dataset_path.as_posix()}")
19
19
  return load_from_disk(dataset_path.as_posix())
20
20
 
21
21
 
22
22
  def download_or_get_dataset_path(
23
23
  dataset_name: str,
24
- endpoint: str,
25
- api_key: str,
24
+ cfg: Optional[Config] = None,
26
25
  output_dir: Optional[str] = None,
27
26
  force_redownload: bool = False,
28
27
  ) -> Path:
29
28
  """Download or get the path of the dataset."""
29
+
30
+ cfg = cfg or Config()
31
+ endpoint_dataset = cfg.get_platform_endpoint("datasets")
32
+ api_key = cfg.api_key
33
+
30
34
  output_dir = output_dir or str(utils.PATH_DATASET)
31
35
  dataset_path_base = Path(output_dir).absolute() / dataset_name
32
36
  dataset_path_base.mkdir(exist_ok=True, parents=True)
33
37
  dataset_path_sample = dataset_path_base / "sample"
34
38
 
35
39
  if dataset_path_sample.exists() and not force_redownload:
36
- logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
40
+ user_logger.info("Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download")
37
41
  return dataset_path_sample
38
42
 
39
- dataset_id = get_dataset_id(dataset_name, endpoint, api_key)
40
- dataset_access_info_url = f"{endpoint}/{dataset_id}/temporary-credentials"
43
+ dataset_id = get_dataset_id(dataset_name, endpoint_dataset, api_key)
44
+ dataset_access_info_url = f"{endpoint_dataset}/{dataset_id}/temporary-credentials"
41
45
 
42
46
  if force_redownload and dataset_path_sample.exists():
43
47
  # Remove old files to avoid old files conflicting with new files
@@ -48,23 +52,6 @@ def download_or_get_dataset_path(
48
52
  raise RuntimeError("Failed to download dataset")
49
53
 
50
54
 
51
- def load_from_platform(
52
- dataset_name: str,
53
- endpoint: str,
54
- api_key: str,
55
- output_dir: Optional[str] = None,
56
- force_redownload: bool = False,
57
- ) -> Union[Dataset, DatasetDict]:
58
- path_dataset = download_or_get_dataset_path(
59
- dataset_name=dataset_name,
60
- endpoint=endpoint,
61
- api_key=api_key,
62
- output_dir=output_dir,
63
- force_redownload=force_redownload,
64
- )
65
- return load_local(path_dataset)
66
-
67
-
68
55
  def load_dataset(dataset_name: str, force_redownload: bool = False) -> Union[Dataset, DatasetDict]:
69
56
  """Load a dataset either from a local path or from the Hafnia platform."""
70
57
 
@@ -72,15 +59,9 @@ def load_dataset(dataset_name: str, force_redownload: bool = False) -> Union[Dat
72
59
  path_dataset = Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
73
60
  return load_local(path_dataset)
74
61
 
75
- cfg = Config()
76
- endpoint_dataset = cfg.get_platform_endpoint("datasets")
77
- api_key = cfg.api_key
78
- dataset = load_from_platform(
62
+ path_dataset = download_or_get_dataset_path(
79
63
  dataset_name=dataset_name,
80
- endpoint=endpoint_dataset,
81
- api_key=api_key,
82
- output_dir=None,
83
64
  force_redownload=force_redownload,
84
65
  )
85
-
66
+ dataset = load_local(path_dataset)
86
67
  return dataset
@@ -13,7 +13,7 @@ from datasets import DatasetDict
13
13
  from pydantic import BaseModel, field_validator
14
14
 
15
15
  from hafnia.data.factory import load_dataset
16
- from hafnia.log import logger
16
+ from hafnia.log import sys_logger, user_logger
17
17
  from hafnia.utils import is_remote_job, now_as_str
18
18
 
19
19
 
@@ -49,7 +49,7 @@ class Entity(BaseModel):
49
49
  try:
50
50
  return float(v)
51
51
  except (ValueError, TypeError) as e:
52
- logger.warning(f"Invalid value '{v}' provided, defaulting to -1.0: {e}")
52
+ user_logger.warning(f"Invalid value '{v}' provided, defaulting to -1.0: {e}")
53
53
  return -1.0
54
54
 
55
55
  @field_validator("ent_type", mode="before")
@@ -159,11 +159,15 @@ class HafniaLogger:
159
159
  def log_hparams(self, params: Dict, fname: str = "hparams.json"):
160
160
  file_path = self._path_artifacts() / fname
161
161
  try:
162
- with open(file_path, "w") as f:
163
- json.dump(params, f, indent=2)
164
- logger.info(f"Saved parameters to {file_path}")
162
+ if file_path.exists(): # New params are appended to existing params
163
+ existing_params = json.loads(file_path.read_text())
164
+ else:
165
+ existing_params = {}
166
+ existing_params.update(params)
167
+ file_path.write_text(json.dumps(existing_params, indent=2))
168
+ user_logger.info(f"Saved parameters to {file_path}")
165
169
  except Exception as e:
166
- logger.error(f"Failed to save parameters to {file_path}: {e}")
170
+ user_logger.error(f"Failed to save parameters to {file_path}: {e}")
167
171
 
168
172
  def log_environment(self):
169
173
  environment_info = {
@@ -197,4 +201,4 @@ class HafniaLogger:
197
201
  next_table = pa.concat_tables([prev, log_batch])
198
202
  pq.write_table(next_table, self.log_file)
199
203
  except Exception as e:
200
- logger.error(f"Failed to flush logs: {e}")
204
+ sys_logger.error(f"Failed to flush logs: {e}")