lightning-sdk 0.1.48__py3-none-any.whl → 0.1.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +3 -1
- lightning_sdk/api/job_api.py +13 -6
- lightning_sdk/api/lit_container_api.py +37 -1
- lightning_sdk/api/mmt_api.py +12 -6
- lightning_sdk/api/utils.py +7 -0
- lightning_sdk/cli/download.py +20 -1
- lightning_sdk/cli/entrypoint.py +11 -0
- lightning_sdk/cli/list.py +60 -2
- lightning_sdk/cli/run.py +19 -4
- lightning_sdk/cli/upload.py +32 -1
- lightning_sdk/job/base.py +23 -4
- lightning_sdk/job/job.py +4 -3
- lightning_sdk/job/v1.py +4 -4
- lightning_sdk/job/v2.py +7 -10
- lightning_sdk/job/work.py +2 -2
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +1 -29
- lightning_sdk/lightning_cloud/openapi/models/v1_lambda_labs_direct_v1.py +31 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_vultr_direct_v1.py +27 -1
- lightning_sdk/lit_container.py +40 -0
- lightning_sdk/mmt/base.py +22 -5
- lightning_sdk/mmt/mmt.py +5 -3
- lightning_sdk/mmt/v1.py +5 -3
- lightning_sdk/mmt/v2.py +11 -10
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/METADATA +1 -1
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/RECORD +30 -30
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__ # noq
|
|
|
4
4
|
from lightning_sdk.helpers import _check_version_and_prompt_upgrade
|
|
5
5
|
from lightning_sdk.job import Job
|
|
6
6
|
from lightning_sdk.machine import Machine
|
|
7
|
+
from lightning_sdk.mmt import MMT
|
|
7
8
|
from lightning_sdk.organization import Organization
|
|
8
9
|
from lightning_sdk.plugin import JobsPlugin, MultiMachineTrainingPlugin, Plugin, SlurmJobsPlugin
|
|
9
10
|
from lightning_sdk.status import Status
|
|
@@ -15,6 +16,7 @@ __all__ = [
|
|
|
15
16
|
"Job",
|
|
16
17
|
"JobsPlugin",
|
|
17
18
|
"Machine",
|
|
19
|
+
"MMT",
|
|
18
20
|
"MultiMachineTrainingPlugin",
|
|
19
21
|
"Organization",
|
|
20
22
|
"Plugin",
|
|
@@ -27,5 +29,5 @@ __all__ = [
|
|
|
27
29
|
"AIHub",
|
|
28
30
|
]
|
|
29
31
|
|
|
30
|
-
__version__ = "0.1.48"
|
|
32
|
+
__version__ = "0.1.50"
|
|
31
33
|
_check_version_and_prompt_upgrade(__version__)
|
lightning_sdk/api/job_api.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
import time
|
|
2
|
-
from typing import TYPE_CHECKING, Dict, List, Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
3
3
|
from urllib.request import urlopen
|
|
4
4
|
|
|
5
5
|
from lightning_sdk.api.utils import (
|
|
6
6
|
_COMPUTE_NAME_TO_MACHINE,
|
|
7
|
-
_MACHINE_TO_COMPUTE_NAME,
|
|
8
7
|
_create_app,
|
|
8
|
+
_machine_to_compute_name,
|
|
9
9
|
remove_datetime_prefix,
|
|
10
10
|
)
|
|
11
11
|
from lightning_sdk.api.utils import (
|
|
@@ -120,7 +120,7 @@ class JobApiV1:
|
|
|
120
120
|
studio_id: str,
|
|
121
121
|
teamspace_id: str,
|
|
122
122
|
cloud_account: str,
|
|
123
|
-
machine: Machine,
|
|
123
|
+
machine: Union[Machine, str],
|
|
124
124
|
interruptible: bool,
|
|
125
125
|
) -> Externalv1LightningappInstance:
|
|
126
126
|
"""Creates an arbitrary app."""
|
|
@@ -130,7 +130,7 @@ class JobApiV1:
|
|
|
130
130
|
teamspace_id=teamspace_id,
|
|
131
131
|
cloud_account=cloud_account,
|
|
132
132
|
plugin_type="job",
|
|
133
|
-
compute=
|
|
133
|
+
compute=_machine_to_compute_name(machine),
|
|
134
134
|
name=name,
|
|
135
135
|
entrypoint=command,
|
|
136
136
|
interruptible=interruptible,
|
|
@@ -180,6 +180,10 @@ class JobApiV1:
|
|
|
180
180
|
|
|
181
181
|
raise RuntimeError("Could not extract command from app")
|
|
182
182
|
|
|
183
|
+
def get_total_cost(self, job: Externalv1LightningappInstance) -> float:
|
|
184
|
+
status: V1LightningappInstanceStatus = job.status
|
|
185
|
+
return status.total_cost
|
|
186
|
+
|
|
183
187
|
|
|
184
188
|
class JobApiV2:
|
|
185
189
|
# these are stages the job can be in.
|
|
@@ -205,7 +209,7 @@ class JobApiV2:
|
|
|
205
209
|
teamspace_id: str,
|
|
206
210
|
studio_id: Optional[str],
|
|
207
211
|
image: Optional[str],
|
|
208
|
-
machine: Machine,
|
|
212
|
+
machine: Union[Machine, str],
|
|
209
213
|
interruptible: bool,
|
|
210
214
|
env: Optional[Dict[str, str]],
|
|
211
215
|
image_credentials: Optional[str],
|
|
@@ -219,7 +223,7 @@ class JobApiV2:
|
|
|
219
223
|
for k, v in env.items():
|
|
220
224
|
env_vars.append(V1EnvVar(name=k, value=v))
|
|
221
225
|
|
|
222
|
-
instance_name =
|
|
226
|
+
instance_name = _machine_to_compute_name(machine)
|
|
223
227
|
|
|
224
228
|
run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
|
|
225
229
|
|
|
@@ -337,3 +341,6 @@ class JobApiV2:
|
|
|
337
341
|
return _COMPUTE_NAME_TO_MACHINE.get(
|
|
338
342
|
instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
|
|
339
343
|
)
|
|
344
|
+
|
|
345
|
+
def get_total_cost(self, job: V1Job) -> float:
|
|
346
|
+
return job.total_cost
|
lightning_sdk/api/lit_container_api.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
|
1
|
-
from typing import List
|
|
1
|
+
from typing import Generator, List
|
|
2
2
|
|
|
3
|
+
from lightning_sdk.api.utils import _get_registry_url
|
|
3
4
|
from lightning_sdk.lightning_cloud.openapi.models import V1DeleteLitRepositoryResponse
|
|
4
5
|
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
6
|
+
from lightning_sdk.teamspace import Teamspace
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class LitContainerApi:
|
|
8
10
|
def __init__(self) -> None:
|
|
9
11
|
self._client = LightningClient(max_tries=3)
|
|
10
12
|
|
|
13
|
+
import docker
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
self._docker_client = docker.from_env()
|
|
17
|
+
self._docker_client.ping()
|
|
18
|
+
except docker.errors.DockerException as e:
|
|
19
|
+
raise RuntimeError(f"Failed to connect to Docker daemon: {e!s}. Is Docker running?") from None
|
|
20
|
+
|
|
11
21
|
def list_containers(self, project_id: str) -> List:
|
|
12
22
|
project = self._client.lit_registry_service_get_lit_project_registry(project_id)
|
|
13
23
|
return project.repositories
|
|
@@ -17,3 +27,29 @@ class LitContainerApi:
|
|
|
17
27
|
return self._client.lit_registry_service_delete_lit_repository(project_id, container)
|
|
18
28
|
except Exception as ex:
|
|
19
29
|
raise ValueError(f"Could not delete container {container} from project {project_id}") from ex
|
|
30
|
+
|
|
31
|
+
def upload_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
|
|
32
|
+
import docker
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
self._docker_client.images.get(container)
|
|
36
|
+
except docker.errors.ImageNotFound:
|
|
37
|
+
raise ValueError(f"Container {container} does not exist") from None
|
|
38
|
+
|
|
39
|
+
registry_url = _get_registry_url()
|
|
40
|
+
repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
|
|
41
|
+
tagged = self._docker_client.api.tag(container, repository, tag)
|
|
42
|
+
if not tagged:
|
|
43
|
+
raise ValueError(f"Could not tag container {container} with {repository}:{tag}")
|
|
44
|
+
return self._docker_client.api.push(repository, stream=True, decode=True)
|
|
45
|
+
|
|
46
|
+
def download_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
|
|
47
|
+
import docker
|
|
48
|
+
|
|
49
|
+
registry_url = _get_registry_url()
|
|
50
|
+
repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
|
|
51
|
+
try:
|
|
52
|
+
self._docker_client.images.pull(repository, tag=tag)
|
|
53
|
+
except docker.errors.APIError as e:
|
|
54
|
+
raise ValueError(f"Could not pull container {container} from {repository}:{tag}") from e
|
|
55
|
+
return self._docker_client.api.tag(repository, container, tag)
|
lightning_sdk/api/mmt_api.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import time
|
|
3
|
-
from typing import TYPE_CHECKING, Dict, List, Optional
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
4
4
|
|
|
5
5
|
from lightning_sdk.api.job_api import JobApiV1
|
|
6
6
|
from lightning_sdk.api.utils import (
|
|
7
7
|
_COMPUTE_NAME_TO_MACHINE,
|
|
8
|
-
_MACHINE_TO_COMPUTE_NAME,
|
|
9
8
|
_create_app,
|
|
9
|
+
_machine_to_compute_name,
|
|
10
10
|
)
|
|
11
11
|
from lightning_sdk.api.utils import (
|
|
12
12
|
_get_cloud_url as _cloud_url,
|
|
@@ -43,13 +43,13 @@ class MMTApiV1(JobApiV1):
|
|
|
43
43
|
cloud_account: Optional[str],
|
|
44
44
|
teamspace_id: str,
|
|
45
45
|
studio_id: str,
|
|
46
|
-
machine: Machine,
|
|
46
|
+
machine: Union[Machine, str],
|
|
47
47
|
interruptible: bool,
|
|
48
48
|
strategy: str,
|
|
49
49
|
) -> Externalv1LightningappInstance:
|
|
50
50
|
"""Creates a multi-machine job with given commands."""
|
|
51
51
|
distributed_args = {
|
|
52
|
-
"cloud_compute":
|
|
52
|
+
"cloud_compute": _machine_to_compute_name(machine),
|
|
53
53
|
"num_instances": num_machines,
|
|
54
54
|
"strategy": strategy,
|
|
55
55
|
}
|
|
@@ -80,7 +80,7 @@ class MMTApiV2:
|
|
|
80
80
|
teamspace_id: str,
|
|
81
81
|
studio_id: Optional[str],
|
|
82
82
|
image: Optional[str],
|
|
83
|
-
machine: Machine,
|
|
83
|
+
machine: Union[Machine, str],
|
|
84
84
|
interruptible: bool,
|
|
85
85
|
env: Optional[Dict[str, str]],
|
|
86
86
|
image_credentials: Optional[str],
|
|
@@ -94,7 +94,7 @@ class MMTApiV2:
|
|
|
94
94
|
for k, v in env.items():
|
|
95
95
|
env_vars.append(V1EnvVar(name=k, value=v))
|
|
96
96
|
|
|
97
|
-
instance_name =
|
|
97
|
+
instance_name = _machine_to_compute_name(machine)
|
|
98
98
|
|
|
99
99
|
run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
|
|
100
100
|
|
|
@@ -203,3 +203,9 @@ class MMTApiV2:
|
|
|
203
203
|
return _COMPUTE_NAME_TO_MACHINE.get(
|
|
204
204
|
instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
|
|
205
205
|
)
|
|
206
|
+
|
|
207
|
+
def get_total_cost(self, job: V1MultiMachineJob) -> float:
|
|
208
|
+
return job.total_cost
|
|
209
|
+
|
|
210
|
+
def get_num_machines(self, job: V1MultiMachineJob) -> int:
|
|
211
|
+
return job.machines
|
lightning_sdk/api/utils.py
CHANGED
|
@@ -354,6 +354,7 @@ def _machine_to_compute_name(machine: Union[Machine, str]) -> str:
|
|
|
354
354
|
_COMPUTE_NAME_TO_MACHINE: Dict[str, Machine] = {v: k for k, v in _MACHINE_TO_COMPUTE_NAME.items()}
|
|
355
355
|
|
|
356
356
|
_DEFAULT_CLOUD_URL = "https://lightning.ai"
|
|
357
|
+
_DEFAULT_REGISTRY_URL = "litcr.io"
|
|
357
358
|
|
|
358
359
|
|
|
359
360
|
def _get_cloud_url() -> str:
|
|
@@ -362,6 +363,12 @@ def _get_cloud_url() -> str:
|
|
|
362
363
|
return cloud_url
|
|
363
364
|
|
|
364
365
|
|
|
366
|
+
def _get_registry_url() -> str:
|
|
367
|
+
registry_url = os.environ.get("LIGHTNING_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
|
|
368
|
+
os.environ["LIGHTNING_REGISTRY_URL"] = registry_url
|
|
369
|
+
return registry_url
|
|
370
|
+
|
|
371
|
+
|
|
365
372
|
def _sanitize_studio_remote_path(path: str, studio_id: str) -> str:
|
|
366
373
|
return f"/cloudspaces/{studio_id}/code/content/{path.replace('/teamspace/studios/this_studio/', '')}"
|
|
367
374
|
|
lightning_sdk/cli/download.py
CHANGED
|
@@ -3,14 +3,18 @@ import re
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Optional
|
|
5
5
|
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
|
|
8
|
+
from lightning_sdk.api.lit_container_api import LitContainerApi
|
|
6
9
|
from lightning_sdk.cli.exceptions import StudioCliError
|
|
7
10
|
from lightning_sdk.cli.studios_menu import _StudiosMenu
|
|
11
|
+
from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
|
|
8
12
|
from lightning_sdk.models import download_model
|
|
9
13
|
from lightning_sdk.studio import Studio
|
|
10
14
|
from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
|
|
11
15
|
|
|
12
16
|
|
|
13
|
-
class _Downloads(_StudiosMenu):
|
|
17
|
+
class _Downloads(_StudiosMenu, _TeamspacesMenu):
|
|
14
18
|
"""Download files and folders from Lightning AI."""
|
|
15
19
|
|
|
16
20
|
def model(self, name: str, download_dir: str = ".") -> None:
|
|
@@ -130,3 +134,18 @@ class _Downloads(_StudiosMenu):
|
|
|
130
134
|
f"Could not download the file from the given Studio {studio}. "
|
|
131
135
|
"Please contact Lightning AI directly to resolve this issue."
|
|
132
136
|
) from e
|
|
137
|
+
|
|
138
|
+
def container(self, container: str, teamspace: Optional[str] = None, tag: str = "latest") -> None:
|
|
139
|
+
"""Download a docker container from a teamspace.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
container: The name of the container to download.
|
|
143
|
+
teamspace: The name of the teamspace to download the container from.
|
|
144
|
+
tag: The tag of the container to download.
|
|
145
|
+
"""
|
|
146
|
+
resolved_teamspace = self._resolve_teamspace(teamspace)
|
|
147
|
+
console = Console()
|
|
148
|
+
with console.status("Downloading container..."):
|
|
149
|
+
api = LitContainerApi()
|
|
150
|
+
api.download_container(container, resolved_teamspace, tag)
|
|
151
|
+
console.print("Container downloaded successfully", style="green")
|
lightning_sdk/cli/entrypoint.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from types import TracebackType
|
|
3
|
+
from typing import Type
|
|
4
|
+
|
|
1
5
|
from fire import Fire
|
|
2
6
|
from lightning_utilities.core.imports import RequirementCache
|
|
3
7
|
|
|
@@ -32,6 +36,8 @@ class StudioCLI:
|
|
|
32
36
|
self.inspect = _Inspect()
|
|
33
37
|
self.stop = _Stop()
|
|
34
38
|
|
|
39
|
+
sys.excepthook = _notify_exception
|
|
40
|
+
|
|
35
41
|
def login(self) -> None:
|
|
36
42
|
"""Login to Lightning AI Studios."""
|
|
37
43
|
auth = Auth()
|
|
@@ -48,6 +54,11 @@ class StudioCLI:
|
|
|
48
54
|
auth.clear()
|
|
49
55
|
|
|
50
56
|
|
|
57
|
+
def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None: # No
|
|
58
|
+
"""CLI won't show tracebacks, just print the exception message."""
|
|
59
|
+
print(value)
|
|
60
|
+
|
|
61
|
+
|
|
51
62
|
def main_cli() -> None:
|
|
52
63
|
"""CLI entrypoint."""
|
|
53
64
|
Fire(StudioCLI(), name="lightning")
|
lightning_sdk/cli/list.py
CHANGED
|
@@ -20,7 +20,36 @@ class _List(_TeamspacesMenu):
|
|
|
20
20
|
"""
|
|
21
21
|
resolved_teamspace = self._resolve_teamspace(teamspace=teamspace)
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
jobs = resolved_teamspace.jobs
|
|
24
|
+
|
|
25
|
+
table = Table(
|
|
26
|
+
pad_edge=True,
|
|
27
|
+
)
|
|
28
|
+
table.add_column("Name")
|
|
29
|
+
table.add_column("Teamspace")
|
|
30
|
+
table.add_column("Studio")
|
|
31
|
+
table.add_column("Image")
|
|
32
|
+
table.add_column("Status")
|
|
33
|
+
table.add_column("Machine")
|
|
34
|
+
table.add_column("Total Cost")
|
|
35
|
+
for j in jobs:
|
|
36
|
+
# we know we just fetched these, so no need to refetch
|
|
37
|
+
j._prevent_refetch_latest = True
|
|
38
|
+
j._internal_job._prevent_refetch_latest = True
|
|
39
|
+
|
|
40
|
+
studio = j.studio
|
|
41
|
+
table.add_row(
|
|
42
|
+
j.name,
|
|
43
|
+
f"{j.teamspace.owner.name}/{j.teamspace.name}",
|
|
44
|
+
studio.name if studio else None,
|
|
45
|
+
j.image,
|
|
46
|
+
str(j.status),
|
|
47
|
+
str(j.machine),
|
|
48
|
+
f"{j.total_cost:.3f}",
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
console = Console()
|
|
52
|
+
console.print(table)
|
|
24
53
|
|
|
25
54
|
def mmts(self, teamspace: Optional[str] = None) -> None:
|
|
26
55
|
"""List multi-machine jobs for a given teamspace.
|
|
@@ -32,7 +61,36 @@ class _List(_TeamspacesMenu):
|
|
|
32
61
|
"""
|
|
33
62
|
resolved_teamspace = self._resolve_teamspace(teamspace=teamspace)
|
|
34
63
|
|
|
35
|
-
|
|
64
|
+
jobs = resolved_teamspace.multi_machine_jobs
|
|
65
|
+
|
|
66
|
+
table = Table(pad_edge=True)
|
|
67
|
+
table.add_column("Name")
|
|
68
|
+
table.add_column("Teamspace")
|
|
69
|
+
table.add_column("Studio")
|
|
70
|
+
table.add_column("Image")
|
|
71
|
+
table.add_column("Status")
|
|
72
|
+
table.add_column("Machine")
|
|
73
|
+
table.add_column("Num Machines")
|
|
74
|
+
table.add_column("Total Cost")
|
|
75
|
+
for j in jobs:
|
|
76
|
+
# we know we just fetched these, so no need to refetch
|
|
77
|
+
j._prevent_refetch_latest = True
|
|
78
|
+
j._internal_job._prevent_refetch_latest = True
|
|
79
|
+
|
|
80
|
+
studio = j.studio
|
|
81
|
+
table.add_row(
|
|
82
|
+
j.name,
|
|
83
|
+
f"{j.teamspace.owner.name}/{j.teamspace.name}",
|
|
84
|
+
studio.name if studio else None,
|
|
85
|
+
j.image,
|
|
86
|
+
str(j.status),
|
|
87
|
+
str(j.machine),
|
|
88
|
+
str(j.num_machines),
|
|
89
|
+
str(j.total_cost),
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
console = Console()
|
|
93
|
+
console.print(table)
|
|
36
94
|
|
|
37
95
|
def containers(self, teamspace: Optional[str] = None) -> None:
|
|
38
96
|
"""Display the list of available containers.
|
lightning_sdk/cli/run.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, Dict, Optional
|
|
1
|
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
|
2
2
|
|
|
3
3
|
from lightning_sdk.job import Job
|
|
4
4
|
from lightning_sdk.machine import Machine
|
|
@@ -111,7 +111,7 @@ class _Run:
|
|
|
111
111
|
# might need to move to different cli library
|
|
112
112
|
def job(
|
|
113
113
|
self,
|
|
114
|
-
name: str,
|
|
114
|
+
name: Optional[str] = None,
|
|
115
115
|
machine: Optional[str] = None,
|
|
116
116
|
command: Optional[str] = None,
|
|
117
117
|
studio: Optional[str] = None,
|
|
@@ -128,16 +128,27 @@ class _Run:
|
|
|
128
128
|
artifacts_remote: Optional[str] = None,
|
|
129
129
|
entrypoint: str = "sh -c",
|
|
130
130
|
) -> None:
|
|
131
|
+
if not name:
|
|
132
|
+
from datetime import datetime
|
|
133
|
+
|
|
134
|
+
timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
|
135
|
+
name = f"job-{timestr}"
|
|
136
|
+
|
|
131
137
|
if machine is None:
|
|
132
138
|
# TODO: infer from studio
|
|
133
139
|
machine = "CPU"
|
|
134
|
-
machine_enum
|
|
140
|
+
machine_enum: Union[str, Machine]
|
|
141
|
+
try:
|
|
142
|
+
machine_enum = Machine[machine.upper()]
|
|
143
|
+
except KeyError:
|
|
144
|
+
machine_enum = machine
|
|
135
145
|
|
|
136
146
|
resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
|
|
137
147
|
|
|
138
148
|
if cloud_account is None:
|
|
139
149
|
cloud_account = resolved_teamspace.default_cloud_account
|
|
140
150
|
machine_enum = Machine(machine.upper())
|
|
151
|
+
|
|
141
152
|
Job.run(
|
|
142
153
|
name=name,
|
|
143
154
|
machine=machine_enum,
|
|
@@ -188,7 +199,11 @@ class _Run:
|
|
|
188
199
|
if machine is None:
|
|
189
200
|
# TODO: infer from studio
|
|
190
201
|
machine = "CPU"
|
|
191
|
-
machine_enum
|
|
202
|
+
machine_enum: Union[str, Machine]
|
|
203
|
+
try:
|
|
204
|
+
machine_enum = Machine[machine.upper()]
|
|
205
|
+
except KeyError:
|
|
206
|
+
machine_enum = machine
|
|
192
207
|
|
|
193
208
|
resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
|
|
194
209
|
if cloud_account is None:
|
lightning_sdk/cli/upload.py
CHANGED
|
@@ -4,18 +4,22 @@ import os
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from typing import Dict, List, Optional
|
|
6
6
|
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
|
|
7
9
|
from simple_term_menu import TerminalMenu
|
|
8
10
|
from tqdm import tqdm
|
|
9
11
|
|
|
12
|
+
from lightning_sdk.api.lit_container_api import LitContainerApi
|
|
10
13
|
from lightning_sdk.api.utils import _get_cloud_url
|
|
11
14
|
from lightning_sdk.cli.exceptions import StudioCliError
|
|
12
15
|
from lightning_sdk.cli.studios_menu import _StudiosMenu
|
|
16
|
+
from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
|
|
13
17
|
from lightning_sdk.models import upload_model
|
|
14
18
|
from lightning_sdk.studio import Studio
|
|
15
19
|
from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
|
|
16
20
|
|
|
17
21
|
|
|
18
|
-
class _Uploads(_StudiosMenu):
|
|
22
|
+
class _Uploads(_StudiosMenu, _TeamspacesMenu):
|
|
19
23
|
"""Upload files and folders to Lightning AI."""
|
|
20
24
|
|
|
21
25
|
_studio_upload_status_path = "~/.lightning/studios/uploads"
|
|
@@ -146,6 +150,33 @@ class _Uploads(_StudiosMenu):
|
|
|
146
150
|
)
|
|
147
151
|
print(f"See your file at {studio_url}")
|
|
148
152
|
|
|
153
|
+
def container(self, container: str, tag: str = "latest", teamspace: Optional[str] = None) -> None:
|
|
154
|
+
teamspace = self._resolve_teamspace(teamspace)
|
|
155
|
+
api = LitContainerApi()
|
|
156
|
+
console = Console()
|
|
157
|
+
with Progress(
|
|
158
|
+
SpinnerColumn(),
|
|
159
|
+
TextColumn("[progress.description]{task.description}"),
|
|
160
|
+
TimeElapsedColumn(),
|
|
161
|
+
console=console,
|
|
162
|
+
transient=False,
|
|
163
|
+
) as progress:
|
|
164
|
+
push_task = progress.add_task("Pushing Docker image", total=None)
|
|
165
|
+
resp = api.upload_container(container, teamspace, tag)
|
|
166
|
+
for line in resp:
|
|
167
|
+
if "status" in line:
|
|
168
|
+
console.print(line["status"], style="bright_black")
|
|
169
|
+
progress.update(push_task, description="Pushing Docker image")
|
|
170
|
+
elif "aux" in line:
|
|
171
|
+
console.print(line["aux"], style="bright_black")
|
|
172
|
+
elif "error" in line:
|
|
173
|
+
progress.stop()
|
|
174
|
+
console.print(f"\n[red]{line}[/red]")
|
|
175
|
+
return
|
|
176
|
+
else:
|
|
177
|
+
console.print(line, style="bright_black")
|
|
178
|
+
progress.update(push_task, description="[green]Container pushed![/green]")
|
|
179
|
+
|
|
149
180
|
def _start_parallel_upload(
|
|
150
181
|
self, executor: concurrent.futures.ThreadPoolExecutor, studio: Studio, upload_state: Dict[str, str]
|
|
151
182
|
) -> List[concurrent.futures.Future]:
|
lightning_sdk/job/base.py
CHANGED
|
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|
|
16
16
|
class MachineDict(TypedDict):
|
|
17
17
|
name: str
|
|
18
18
|
status: "Status"
|
|
19
|
-
machine: "Machine"
|
|
19
|
+
machine: Union["Machine", str]
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class JobDict(MachineDict):
|
|
@@ -24,6 +24,7 @@ class JobDict(MachineDict):
|
|
|
24
24
|
teamspace: str
|
|
25
25
|
studio: Optional[str]
|
|
26
26
|
image: Optional[str]
|
|
27
|
+
total_cost: float
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class _BaseJob(ABC):
|
|
@@ -61,11 +62,13 @@ class _BaseJob(ABC):
|
|
|
61
62
|
if _fetch_job:
|
|
62
63
|
self._update_internal_job()
|
|
63
64
|
|
|
65
|
+
self._prevent_refetch_latest = False
|
|
66
|
+
|
|
64
67
|
@classmethod
|
|
65
68
|
def run(
|
|
66
69
|
cls,
|
|
67
70
|
name: str,
|
|
68
|
-
machine: "Machine",
|
|
71
|
+
machine: Union["Machine", str],
|
|
69
72
|
command: Optional[str] = None,
|
|
70
73
|
studio: Union["Studio", str, None] = None,
|
|
71
74
|
image: Optional[str] = None,
|
|
@@ -202,7 +205,7 @@ class _BaseJob(ABC):
|
|
|
202
205
|
@abstractmethod
|
|
203
206
|
def _submit(
|
|
204
207
|
self,
|
|
205
|
-
machine: "Machine",
|
|
208
|
+
machine: Union["Machine", str],
|
|
206
209
|
command: Optional[str] = None,
|
|
207
210
|
studio: Optional["Studio"] = None,
|
|
208
211
|
image: Optional[str] = None,
|
|
@@ -268,7 +271,7 @@ class _BaseJob(ABC):
|
|
|
268
271
|
|
|
269
272
|
@property
|
|
270
273
|
@abstractmethod
|
|
271
|
-
def machine(self) -> "Machine":
|
|
274
|
+
def machine(self) -> Union["Machine", str]:
|
|
272
275
|
"""The machine type the job is running on."""
|
|
273
276
|
|
|
274
277
|
@property
|
|
@@ -332,6 +335,7 @@ class _BaseJob(ABC):
|
|
|
332
335
|
"command": self.command,
|
|
333
336
|
"status": self.status,
|
|
334
337
|
"machine": self.machine,
|
|
338
|
+
"total_cost": self.total_cost,
|
|
335
339
|
}
|
|
336
340
|
|
|
337
341
|
def json(self) -> str:
|
|
@@ -361,3 +365,18 @@ class _BaseJob(ABC):
|
|
|
361
365
|
self._update_internal_job()
|
|
362
366
|
|
|
363
367
|
return self._job
|
|
368
|
+
|
|
369
|
+
@property
|
|
370
|
+
def total_cost(self) -> float:
|
|
371
|
+
"""The number of credits the job was consuming so far."""
|
|
372
|
+
return self._job_api.get_total_cost(self._latest_job)
|
|
373
|
+
|
|
374
|
+
@property
|
|
375
|
+
def _latest_job(self) -> Any:
|
|
376
|
+
"""Guarantees to fetch the latest version of a job before returning it."""
|
|
377
|
+
# in some cases we know we just refetched the latest state, no need to refetch again
|
|
378
|
+
if self._prevent_refetch_latest:
|
|
379
|
+
return self._guaranteed_job
|
|
380
|
+
|
|
381
|
+
self._update_internal_job()
|
|
382
|
+
return self._job
|
lightning_sdk/job/job.py
CHANGED
|
@@ -89,7 +89,7 @@ class Job(_BaseJob):
|
|
|
89
89
|
def run(
|
|
90
90
|
cls,
|
|
91
91
|
name: str,
|
|
92
|
-
machine: "Machine",
|
|
92
|
+
machine: Union["Machine", str],
|
|
93
93
|
command: Optional[str] = None,
|
|
94
94
|
studio: Union["Studio", str, None] = None,
|
|
95
95
|
image: Union[str, None] = None,
|
|
@@ -169,7 +169,7 @@ class Job(_BaseJob):
|
|
|
169
169
|
|
|
170
170
|
def _submit(
|
|
171
171
|
self,
|
|
172
|
-
machine: "Machine",
|
|
172
|
+
machine: Union["Machine", str],
|
|
173
173
|
command: Optional[str] = None,
|
|
174
174
|
studio: Optional["Studio"] = None,
|
|
175
175
|
image: Optional[str] = None,
|
|
@@ -225,6 +225,7 @@ class Job(_BaseJob):
|
|
|
225
225
|
cloud_account_auth=cloud_account_auth,
|
|
226
226
|
artifacts_local=artifacts_local,
|
|
227
227
|
artifacts_remote=artifacts_remote,
|
|
228
|
+
entrypoint=entrypoint,
|
|
228
229
|
)
|
|
229
230
|
return self
|
|
230
231
|
|
|
@@ -248,7 +249,7 @@ class Job(_BaseJob):
|
|
|
248
249
|
return self._internal_job.status
|
|
249
250
|
|
|
250
251
|
@property
|
|
251
|
-
def machine(self) -> "Machine":
|
|
252
|
+
def machine(self) -> Union["Machine", str]:
|
|
252
253
|
"""The machine type the job is running on."""
|
|
253
254
|
return self._internal_job.machine
|
|
254
255
|
|
lightning_sdk/job/v1.py
CHANGED
|
@@ -44,7 +44,7 @@ class _JobV1(_BaseJob):
|
|
|
44
44
|
def run(
|
|
45
45
|
cls,
|
|
46
46
|
name: str,
|
|
47
|
-
machine: "Machine",
|
|
47
|
+
machine: Union["Machine", str],
|
|
48
48
|
command: str,
|
|
49
49
|
studio: "Studio",
|
|
50
50
|
teamspace: Union[str, "Teamspace", None] = None,
|
|
@@ -89,7 +89,7 @@ class _JobV1(_BaseJob):
|
|
|
89
89
|
|
|
90
90
|
def _submit(
|
|
91
91
|
self,
|
|
92
|
-
machine: "Machine",
|
|
92
|
+
machine: Union["Machine", str],
|
|
93
93
|
command: Optional[str] = None,
|
|
94
94
|
studio: Optional["Studio"] = None,
|
|
95
95
|
image: Optional[str] = None,
|
|
@@ -174,7 +174,7 @@ class _JobV1(_BaseJob):
|
|
|
174
174
|
|
|
175
175
|
def stop(self) -> None:
|
|
176
176
|
"""Stops the job. is blocking until the ob is stopped."""
|
|
177
|
-
if self.status in (Status.Stopped, Status.Failed):
|
|
177
|
+
if self.status in (Status.Stopped, Status.Completed, Status.Failed):
|
|
178
178
|
return None
|
|
179
179
|
|
|
180
180
|
return self._job_api.stop_job(self._job.id, self.teamspace.id)
|
|
@@ -195,7 +195,7 @@ class _JobV1(_BaseJob):
|
|
|
195
195
|
return Work(_work[0].id, self, self.teamspace)
|
|
196
196
|
|
|
197
197
|
@property
|
|
198
|
-
def machine(self) -> "Machine":
|
|
198
|
+
def machine(self) -> Union["Machine", str]:
|
|
199
199
|
"""Get the machine the job is running on."""
|
|
200
200
|
return self.work.machine
|
|
201
201
|
|
lightning_sdk/job/v2.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING,
|
|
1
|
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
|
2
2
|
|
|
3
3
|
from lightning_sdk.api.job_api import JobApiV2
|
|
4
4
|
from lightning_sdk.api.utils import _get_cloud_url
|
|
5
5
|
from lightning_sdk.job.base import _BaseJob
|
|
6
|
+
from lightning_sdk.status import Status
|
|
6
7
|
|
|
7
8
|
if TYPE_CHECKING:
|
|
8
9
|
from lightning_sdk.machine import Machine
|
|
9
10
|
from lightning_sdk.organization import Organization
|
|
10
|
-
from lightning_sdk.status import Status
|
|
11
11
|
from lightning_sdk.studio import Studio
|
|
12
12
|
from lightning_sdk.teamspace import Teamspace
|
|
13
13
|
from lightning_sdk.user import User
|
|
@@ -37,7 +37,7 @@ class _JobV2(_BaseJob):
|
|
|
37
37
|
|
|
38
38
|
def _submit(
|
|
39
39
|
self,
|
|
40
|
-
machine: "Machine",
|
|
40
|
+
machine: Union["Machine", str],
|
|
41
41
|
command: Optional[str] = None,
|
|
42
42
|
studio: Optional["Studio"] = None,
|
|
43
43
|
image: Optional[str] = None,
|
|
@@ -121,6 +121,9 @@ class _JobV2(_BaseJob):
|
|
|
121
121
|
|
|
122
122
|
def stop(self) -> None:
|
|
123
123
|
"""Stop the job. If the job is already stopped, this is a no-op. This is blocking until the job is stopped."""
|
|
124
|
+
if self.status in (Status.Stopped, Status.Completed, Status.Failed):
|
|
125
|
+
return
|
|
126
|
+
|
|
124
127
|
self._job_api.stop_job(job_id=self._guaranteed_job.id, teamspace_id=self._teamspace.id)
|
|
125
128
|
|
|
126
129
|
def delete(self) -> None:
|
|
@@ -134,19 +137,13 @@ class _JobV2(_BaseJob):
|
|
|
134
137
|
cloudspace_id=self._guaranteed_job.spec.cloudspace_id,
|
|
135
138
|
)
|
|
136
139
|
|
|
137
|
-
@property
|
|
138
|
-
def _latest_job(self) -> Any:
|
|
139
|
-
"""Guarantees to fetch the latest version of a job before returning it."""
|
|
140
|
-
self._update_internal_job()
|
|
141
|
-
return self._job
|
|
142
|
-
|
|
143
140
|
@property
|
|
144
141
|
def status(self) -> "Status":
|
|
145
142
|
"""The current status of the job."""
|
|
146
143
|
return self._job_api._job_state_to_external(self._latest_job.state)
|
|
147
144
|
|
|
148
145
|
@property
|
|
149
|
-
def machine(self) -> "Machine":
|
|
146
|
+
def machine(self) -> Union["Machine", str]:
|
|
150
147
|
"""The machine type the job is running on."""
|
|
151
148
|
# only fetch the job it it hasn't been fetched yet as machine cannot change over time
|
|
152
149
|
return self._job_api._get_job_machine_from_spec(self._guaranteed_job.spec)
|