hafnia 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +0 -0
- cli/__main__.py +63 -0
- cli/config.py +148 -0
- cli/consts.py +18 -0
- cli/data_cmds.py +58 -0
- cli/experiment_cmds.py +82 -0
- cli/profile_cmds.py +86 -0
- cli/runc_cmds.py +70 -0
- hafnia/__init__.py +4 -0
- hafnia/data/__init__.py +3 -0
- hafnia/data/factory.py +88 -0
- hafnia/experiment/__init__.py +3 -0
- hafnia/experiment/mdi_logger.py +200 -0
- hafnia/http.py +85 -0
- hafnia/log.py +32 -0
- hafnia/platform/__init__.py +23 -0
- hafnia/platform/api.py +12 -0
- hafnia/platform/builder.py +198 -0
- hafnia/platform/download.py +126 -0
- hafnia/platform/executor.py +115 -0
- hafnia/platform/experiment.py +68 -0
- hafnia/torch_helpers.py +185 -0
- hafnia/utils.py +83 -0
- hafnia-0.1.8.dist-info/METADATA +43 -0
- hafnia-0.1.8.dist-info/RECORD +27 -0
- hafnia-0.1.8.dist-info/WHEEL +4 -0
- hafnia-0.1.8.dist-info/entry_points.txt +2 -0
cli/__init__.py
ADDED
|
File without changes
|
cli/__main__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
#!/usr/bin/env python
import click

from cli import consts, data_cmds, experiment_cmds, profile_cmds, runc_cmds
from cli.config import Config, ConfigSchema


@click.group()
@click.pass_context
def main(ctx: click.Context) -> None:
    """MDI CLI."""
    # One shared Config instance, handed to subcommands through the context.
    ctx.obj = Config()


@main.command("configure")
@click.pass_obj
def configure(cfg: Config) -> None:
    """Configure MDI CLI settings."""

    from hafnia.platform.api import get_organization_id

    name = click.prompt("Profile Name", type=str, default="default").strip()
    try:
        cfg.add_profile(name, ConfigSchema(), set_active=True)
    except ValueError:
        raise click.ClickException(consts.ERROR_CREATE_PROFILE)

    entered_key = click.prompt("MDI API Key", type=str, hide_input=True)
    try:
        cfg.api_key = entered_key.strip()
    except ValueError as e:
        click.echo(f"Error: {str(e)}", err=True)
        return
    url = click.prompt(
        "MDI Platform URL", type=str, default="https://api.mdi.milestonesys.com"
    )
    cfg.platform_url = url.strip()
    try:
        cfg.organization_id = get_organization_id(
            cfg.get_platform_endpoint("organizations"), cfg.api_key
        )
    except Exception:
        raise click.ClickException(consts.ERROR_ORG_ID)
    cfg.save_config()
    profile_cmds.profile_show(cfg)


@main.command("clear")
@click.pass_obj
def clear(cfg: Config) -> None:
    """Remove stored configuration."""
    cfg.clear()
    click.echo("Successfully cleared MDI configuration.")


# Register the command groups implemented in the sibling modules.
for _group in (
    profile_cmds.profile,
    data_cmds.data,
    runc_cmds.runc,
    experiment_cmds.experiment,
):
    main.add_command(_group)

if __name__ == "__main__":
    main()
|
cli/config.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, field_validator
|
|
7
|
+
|
|
8
|
+
import cli.consts as consts
|
|
9
|
+
from hafnia.log import logger
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConfigSchema(BaseModel):
    """Settings stored for one named CLI profile."""

    organization_id: str = ""
    platform_url: str = ""
    # None until the user has supplied a key via `mdi configure`.
    api_key: Optional[str] = None
    # Endpoint name -> full URL; rebuilt whenever platform_url changes.
    api_mapping: Optional[Dict[str, str]] = None

    @field_validator("api_key")
    def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
        # Reject obviously invalid (too short) keys; None (unset) is allowed.
        if value is not None and len(value) < 10:
            raise ValueError("API key is too short.")
        return value
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ConfigFileSchema(BaseModel):
    """On-disk layout of the config file: all profiles plus the active one."""

    # Name of the profile currently in use; None when nothing is configured.
    active_profile: Optional[str] = None
    # Profile name -> its settings.
    profiles: Dict[str, ConfigSchema] = {}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Config:
    """Wrapper around the MDI CLI configuration file and its named profiles."""

    def __init__(self, config_path: Optional[Path] = None) -> None:
        self.config_path = self.resolve_config_path(config_path)
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        self.config_data = self.load_config()

    @property
    def available_profiles(self) -> List[str]:
        """Names of all profiles stored in the config file."""
        return list(self.config_data.profiles)

    @property
    def active_profile(self) -> str:
        name = self.config_data.active_profile
        if name is None:
            raise ValueError(consts.ERROR_PROFILE_NOT_EXIST)
        return name

    @active_profile.setter
    def active_profile(self, value: str) -> None:
        candidate = value.strip()
        if candidate not in self.config_data.profiles:
            raise ValueError(f"Profile '{candidate}' does not exist.")
        self.config_data.active_profile = candidate

    @property
    def config(self) -> ConfigSchema:
        """Settings of the active profile; raises when none is configured."""
        active = self.config_data.active_profile
        if not active:
            raise ValueError(consts.ERROR_PROFILE_NOT_EXIST)
        return self.config_data.profiles[active]

    @property
    def api_key(self) -> str:
        key = self.config.api_key
        if key is None:
            raise ValueError(consts.ERROR_API_KEY_NOT_SET)
        return key

    @api_key.setter
    def api_key(self, value: str) -> None:
        self.config.api_key = value

    @property
    def organization_id(self) -> str:
        return self.config.organization_id

    @organization_id.setter
    def organization_id(self, value: str) -> None:
        self.config.organization_id = value

    @property
    def platform_url(self) -> str:
        return self.config.platform_url

    @platform_url.setter
    def platform_url(self, value: str) -> None:
        # Normalise the URL and refresh the derived endpoint map together,
        # so they can never get out of sync.
        base_url = value.rstrip("/")
        self.config.platform_url = base_url
        self.config.api_mapping = self.get_api_mapping(base_url)

    def resolve_config_path(self, path: Optional[Path] = None) -> Path:
        """Pick the config location: explicit arg > $MDI_CONFIG_PATH > ~/.mdi."""
        if path:
            return Path(path).expanduser()
        env_path = os.getenv("MDI_CONFIG_PATH")
        if env_path:
            return Path(env_path).expanduser()
        return Path.home() / ".mdi" / "config.json"

    def add_profile(
        self, profile_name: str, profile: ConfigSchema, set_active: bool = False
    ) -> None:
        """Store (or overwrite) a profile and persist the file."""
        key = profile_name.strip()
        self.config_data.profiles[key] = profile
        if set_active:
            self.config_data.active_profile = key
        self.save_config()

    def get_api_mapping(self, base_url: str) -> Dict:
        # NOTE(review): "runs" maps to ".../experiments-runs" while
        # "experiment_runs" maps to ".../experiment-runs"; one of the two
        # looks like a typo -- confirm against the platform API before changing.
        return {
            "organizations": f"{base_url}/api/v1/organizations",
            "recipes": f"{base_url}/api/v1/recipes",
            "experiments": f"{base_url}/api/v1/experiments",
            "experiment_environments": f"{base_url}/api/v1/experiment-environments",
            "experiment_runs": f"{base_url}/api/v1/experiment-runs",
            "runs": f"{base_url}/api/v1/experiments-runs",
            "datasets": f"{base_url}/api/v1/datasets",
        }

    def get_platform_endpoint(self, method: str) -> str:
        """Get specific API endpoint"""
        mapping = self.config.api_mapping
        if not mapping or method not in mapping:
            raise ValueError(f"{method} is not supported.")
        return mapping[method]

    def load_config(self) -> ConfigFileSchema:
        """Load configuration from file."""
        if not self.config_path.exists():
            return ConfigFileSchema()
        try:
            with open(self.config_path.as_posix(), "r") as f:
                return ConfigFileSchema(**json.load(f))
        except json.JSONDecodeError:
            logger.error("Error decoding JSON file.")
            raise ValueError("Failed to parse configuration file")

    def save_config(self) -> None:
        """Persist the in-memory configuration to disk."""
        with open(self.config_path, "w") as f:
            json.dump(self.config_data.model_dump(), f, indent=4)

    def remove_profile(self, profile_name: str) -> None:
        """Delete a profile and persist the change."""
        if profile_name not in self.config_data.profiles:
            raise ValueError(f"Profile '{profile_name}' does not exist.")
        del self.config_data.profiles[profile_name]
        self.save_config()

    def clear(self) -> None:
        """Forget all profiles and delete the config file if present."""
        self.config_data = ConfigFileSchema(active_profile=None, profiles={})
        if self.config_path.exists():
            self.config_path.unlink()
|
cli/consts.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Centralised user-facing message strings for the MDI CLI.

# Errors raised when configuration is missing or malformed.
ERROR_CONFIGURE: str = "Please configure the CLI with `mdi configure`"
ERROR_PROFILE_NOT_EXIST: str = (
    "No active profile configured. Please configure the CLI with `mdi configure`"
)
ERROR_PROFILE_REMOVE_ACTIVE: str = (
    "Cannot remove active profile. Please switch to another profile first."
)
ERROR_API_KEY_NOT_SET: str = "API key not set. Please configure the CLI with `mdi configure`."
ERROR_ORG_ID: str = "Failed to fetch organization ID. Verify platform URL and API key."
ERROR_CREATE_PROFILE: str = "Failed to create profile. Profile name must be unique and not empty."

# Errors from platform/data operations.
ERROR_GET_RESOURCE: str = "Failed to get the data from platform. Verify url or api key."

ERROR_EXPERIMENT_DIR: str = "Source directory does not exist"

# Success/status message prefixes used by the profile commands.
PROFILE_SWITCHED_SUCCESS: str = "Switched to profile:"
PROFILE_REMOVED_SUCCESS: str = "Removed profile:"
PROFILE_TABLE_HEADER: str = "MDI Platform Profile:"
|
cli/data_cmds.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
from rich import print as rprint
|
|
5
|
+
|
|
6
|
+
import cli.consts as consts
|
|
7
|
+
from cli.config import Config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@click.group()
def data():
    """Manage data interaction"""
    # Container group; subcommands are registered below.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@data.command("get")
@click.argument("url")
@click.argument("destination")
@click.pass_obj
def data_get(cfg: Config, url: str, destination: click.Path) -> None:
    """Download resource from MDI platform"""

    # Imported lazily so the CLI starts fast when this command is unused.
    from hafnia.platform import download_resource

    try:
        outcome = download_resource(
            api_key=cfg.api_key, destination=str(destination), resource_url=url
        )
    except Exception:
        raise click.ClickException(consts.ERROR_GET_RESOURCE)
    rprint(outcome)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@data.command("download")
@click.argument("dataset_name")
@click.argument("destination", default=None, required=False)
@click.option("--force", is_flag=True, default=False, help="Force download")
@click.pass_obj
def data_download(
    cfg: Config, dataset_name: str, destination: Optional[click.Path], force: bool
) -> None:
    """Download dataset from MDI platform"""

    from hafnia.data.factory import download_or_get_dataset_path

    try:
        endpoint_dataset = cfg.get_platform_endpoint("datasets")
        api_key = cfg.api_key
        # BUG FIX: the factory function's keyword is `force_redownload`,
        # not `force`. The old `force=force` raised a TypeError that the
        # broad except below swallowed, so the command always failed with
        # ERROR_GET_RESOURCE.
        download_or_get_dataset_path(
            dataset_name=dataset_name,
            endpoint=endpoint_dataset,
            api_key=api_key,
            output_dir=destination,
            force_redownload=force,
        )
    except Exception:
        raise click.ClickException(consts.ERROR_GET_RESOURCE)
|
cli/experiment_cmds.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
from rich import print as rprint
|
|
5
|
+
|
|
6
|
+
import cli.consts as consts
|
|
7
|
+
from cli.config import Config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@click.group(name="experiment")
def experiment() -> None:
    """Experiment management commands"""
    # Container group; subcommands are registered below.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@experiment.command(name="create")
@click.argument("name")
@click.argument("source_dir", type=Path)
@click.argument("exec_cmd", type=str)
@click.argument("dataset_name")
@click.argument("env_name")
@click.pass_obj
def create(
    cfg: Config, name: str, source_dir: Path, exec_cmd: str, dataset_name: str, env_name: str
) -> None:
    """Create a new experiment run"""
    from hafnia.platform import (
        create_experiment,
        create_recipe,
        get_dataset_id,
        get_exp_environment_id,
    )

    # Fail fast before making any platform calls.
    if not source_dir.exists():
        raise click.ClickException(consts.ERROR_EXPERIMENT_DIR)

    # Step 1: resolve the dataset.
    try:
        ds_id = get_dataset_id(
            dataset_name, cfg.get_platform_endpoint("datasets"), cfg.api_key
        )
    except (IndexError, KeyError):
        raise click.ClickException(f"Dataset '{dataset_name}' not found.")
    except Exception:
        raise click.ClickException(f"Error retrieving dataset '{dataset_name}'.")

    # Step 2: upload the recipe.
    try:
        rec_id = create_recipe(
            source_dir, cfg.get_platform_endpoint("recipes"), cfg.api_key, cfg.organization_id
        )
    except Exception:
        raise click.ClickException(f"Failed to create recipe from '{source_dir}'")

    # Step 3: resolve the execution environment.
    try:
        environment_id = get_exp_environment_id(
            env_name, cfg.get_platform_endpoint("experiment_environments"), cfg.api_key
        )
    except Exception:
        raise click.ClickException(f"Environment '{env_name}' not found")

    # Step 4: create the experiment itself.
    try:
        exp_id = create_experiment(
            name,
            ds_id,
            rec_id,
            exec_cmd,
            environment_id,
            cfg.get_platform_endpoint("experiments"),
            cfg.api_key,
            cfg.organization_id,
        )
    except Exception:
        raise click.ClickException(f"Failed to create experiment '{name}'")

    summary = {
        "dataset_id": ds_id,
        "recipe_id": rec_id,
        "environment_id": environment_id,
        "experiment_id": exp_id,
        "status": "CREATED",
    }
    rprint(summary)
|
cli/profile_cmds.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import click
|
|
2
|
+
from rich.console import Console
|
|
3
|
+
from rich.table import Table
|
|
4
|
+
|
|
5
|
+
import cli.consts as consts
|
|
6
|
+
from cli.config import Config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group()
def profile():
    """Manage profile."""
    # Container group; subcommands are registered below.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@profile.command("ls")
@click.pass_obj
def profile_ls(cfg: Config) -> None:
    """List all available profiles."""
    profiles = cfg.available_profiles
    if not profiles:
        raise click.ClickException(consts.ERROR_CONFIGURE)
    active = cfg.active_profile

    # Consistency fix: use click.echo like the sibling commands (handles
    # encoding/piping correctly), and avoid shadowing the module-level
    # `profile` group with the loop variable.
    for profile_name in profiles:
        status = "* " if profile_name == active else " "
        click.echo(f"{status}{profile_name}")

    click.echo(f"\nActive profile: {active}")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@profile.command("use")
@click.argument("profile_name", required=True)
@click.pass_obj
def profile_use(cfg: Config, profile_name: str) -> None:
    """Switch to a different profile."""
    # Nothing to switch to if configuration has never been run.
    if not cfg.available_profiles:
        raise click.ClickException(consts.ERROR_CONFIGURE)
    try:
        cfg.active_profile = profile_name
        cfg.save_config()
    except ValueError:
        raise click.ClickException(consts.ERROR_PROFILE_NOT_EXIST)
    click.echo(f"{consts.PROFILE_SWITCHED_SUCCESS} {profile_name}")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@profile.command("rm")
@click.argument("profile_name", required=True)
@click.pass_obj
def profile_rm(cfg: Config, profile_name: str) -> None:
    """Remove a profile."""
    # Nothing to remove if configuration has never been run.
    if not cfg.available_profiles:
        raise click.ClickException(consts.ERROR_CONFIGURE)

    # The active profile must be switched away from before removal.
    if profile_name == cfg.active_profile:
        raise click.ClickException(consts.ERROR_PROFILE_REMOVE_ACTIVE)

    try:
        cfg.remove_profile(profile_name)
        cfg.save_config()
    except ValueError:
        raise click.ClickException(consts.ERROR_PROFILE_NOT_EXIST)
    click.echo(f"{consts.PROFILE_REMOVED_SUCCESS} {profile_name}")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@profile.command("active")
@click.pass_obj
def profile_active(cfg: Config) -> None:
    # Show the active profile, surfacing any failure as a CLI error.
    try:
        profile_show(cfg)
    except Exception as err:
        raise click.ClickException(str(err))
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def profile_show(cfg: Config) -> None:
    """Render the active profile's settings as a table; the API key is masked."""
    key = cfg.api_key
    if len(key) > 8:
        masked_key = f"{key[:4]}...{key[-4:]}"
    else:
        masked_key = "****"

    table = Table(title=f"{consts.PROFILE_TABLE_HEADER} {cfg.active_profile}", show_header=False)
    table.add_column("Property", style="cyan")
    table.add_column("Value")

    for label, value in (
        ("API Key", masked_key),
        ("Organization", cfg.organization_id),
        ("Platform URL", cfg.platform_url),
        ("Config File", cfg.config_path.as_posix()),
    ):
        table.add_row(label, value)
    Console().print(table)
|
cli/runc_cmds.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from hashlib import sha256
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from tempfile import TemporaryDirectory
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from cli.config import Config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@click.group(name="runc")
def runc():
    """Experiment management commands"""
    # Container group; subcommands are registered below.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@runc.command(name="launch")
@click.argument("task", required=True)
def launch(task: str) -> None:
    """Launch a job within the image."""
    # Imported lazily; the executor is only needed inside the runtime image.
    from hafnia.platform.executor import handle_launch

    handle_launch(task)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@runc.command(name="build")
@click.argument("recipe_url")
@click.argument("state_file", default="state.json")
@click.argument("ecr_repository", default="localhost")
@click.argument("image_name", default="recipe")
@click.pass_obj
def build(
    cfg: Config, recipe_url: str, state_file: str, ecr_repository: str, image_name: str
) -> None:
    """Build docker image with a given recipe."""
    from hafnia.platform.builder import build_image, prepare_recipe

    # The recipe is fetched into a scratch directory that is removed afterwards.
    with TemporaryDirectory() as workdir:
        info = prepare_recipe(recipe_url, Path(workdir), cfg.api_key)
        info["name"] = image_name
        build_image(info, ecr_repository, state_file)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@runc.command(name="build-local")
@click.argument("recipe")
@click.argument("state_file", default="state.json")
@click.argument("image_name", default="recipe")
def build_local(recipe: str, state_file: str, image_name: str) -> None:
    """Build recipe from local path as image with prefix - localhost"""

    from hafnia.platform.builder import build_image, validate_recipe
    from hafnia.utils import archive_dir

    recipe_zip = Path(recipe)
    recipe_created = False
    # A directory argument is archived into a zip first; a .zip path is used as-is.
    if not recipe_zip.suffix == ".zip" and recipe_zip.is_dir():
        recipe_zip = archive_dir(recipe_zip)
        recipe_created = True

    try:
        validate_recipe(recipe_zip)
        click.echo("Recipe successfully validated")
        image_info = {
            "name": image_name,
            "dockerfile": f"{recipe_zip.parent}/Dockerfile",
            "docker_context": f"{recipe_zip.parent}",
            # Short content hash used to tag/identify the built image.
            "hash": sha256(recipe_zip.read_bytes()).hexdigest()[:8],
        }
        click.echo("Start building image")
        build_image(image_info, "localhost", state_file=state_file)
    finally:
        # Robustness fix: previously the temporary archive leaked whenever
        # validate_recipe() or build_image() raised; always remove the zip
        # we created, success or failure.
        if recipe_created:
            recipe_zip.unlink()
|
hafnia/__init__.py
ADDED
hafnia/data/__init__.py
ADDED
hafnia/data/factory.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional, Union
|
|
5
|
+
|
|
6
|
+
from datasets import Dataset, DatasetDict, load_from_disk
|
|
7
|
+
|
|
8
|
+
from cli.config import Config
|
|
9
|
+
from hafnia import utils
|
|
10
|
+
from hafnia.log import logger
|
|
11
|
+
from hafnia.platform import download_resource, get_dataset_id
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_local(dataset_path: Path) -> Union[Dataset, DatasetDict]:
    """Load a Hugging Face dataset from a local directory path."""
    if not dataset_path.exists():
        raise ValueError(f"Can not load dataset, directory does not exist -- {dataset_path}")
    posix_path = dataset_path.as_posix()
    logger.info(f"Loading data from {posix_path}")
    return load_from_disk(posix_path)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def download_or_get_dataset_path(
    dataset_name: str,
    endpoint: str,
    api_key: str,
    output_dir: Optional[str] = None,
    force_redownload: bool = False,
) -> Path:
    """Download or get the path of the dataset."""
    target_root = output_dir or str(utils.PATH_DATASET)
    base_dir = Path(target_root).absolute() / dataset_name
    base_dir.mkdir(exist_ok=True, parents=True)
    sample_dir = base_dir / "sample"

    # Reuse the cached copy unless the caller explicitly forces a re-download.
    if sample_dir.exists() and not force_redownload:
        logger.info(
            "Dataset found locally. Set 'force=True' or add `--force` flag with cli to re-download"
        )
        return sample_dir

    dataset_id = get_dataset_id(dataset_name, endpoint, api_key)
    credentials_url = f"{endpoint}/{dataset_id}/temporary-credentials"

    if force_redownload and sample_dir.exists():
        # Remove old files to avoid old files conflicting with new files
        shutil.rmtree(sample_dir, ignore_errors=True)
    if download_resource(credentials_url, base_dir, api_key):
        return sample_dir
    raise RuntimeError("Failed to download dataset")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_from_platform(
    dataset_name: str,
    endpoint: str,
    api_key: str,
    output_dir: Optional[str] = None,
    force_redownload: bool = False,
) -> Union[Dataset, DatasetDict]:
    """Fetch the dataset via the platform if needed, then load it from disk."""
    local_path = download_or_get_dataset_path(
        dataset_name=dataset_name,
        endpoint=endpoint,
        api_key=api_key,
        output_dir=output_dir,
        force_redownload=force_redownload,
    )
    return load_local(local_path)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def load_dataset(dataset_name: str, force_redownload: bool = False) -> Union[Dataset, DatasetDict]:
    """Load a dataset either from a local path or from the MDI platform."""

    if utils.is_remote_job():
        # Inside a training job the dataset is already mounted on local disk.
        mounted = Path(os.getenv("MDI_DATASET_DIR", "/opt/ml/input/data/training"))
        return load_local(mounted)

    cfg = Config()
    return load_from_platform(
        dataset_name=dataset_name,
        endpoint=cfg.get_platform_endpoint("datasets"),
        api_key=cfg.api_key,
        output_dir=None,
        force_redownload=force_redownload,
    )
|