lightning-sdk 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/_mmt/__init__.py +3 -0
- lightning_sdk/_mmt/base.py +180 -0
- lightning_sdk/_mmt/mmt.py +161 -0
- lightning_sdk/_mmt/v1.py +69 -0
- lightning_sdk/_mmt/v2.py +141 -0
- lightning_sdk/api/deployment_api.py +0 -2
- lightning_sdk/api/job_api.py +4 -0
- lightning_sdk/api/mmt_api.py +147 -0
- lightning_sdk/api/teamspace_api.py +4 -11
- lightning_sdk/api/utils.py +6 -3
- lightning_sdk/cli/download.py +3 -5
- lightning_sdk/cli/mmt.py +137 -0
- lightning_sdk/cli/run.py +16 -0
- lightning_sdk/cli/upload.py +3 -10
- lightning_sdk/job/base.py +24 -3
- lightning_sdk/job/job.py +10 -1
- lightning_sdk/job/v1.py +7 -1
- lightning_sdk/job/v2.py +18 -9
- lightning_sdk/lightning_cloud/openapi/__init__.py +7 -3
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +90 -284
- lightning_sdk/lightning_cloud/openapi/api/data_connection_service_api.py +6 -1
- lightning_sdk/lightning_cloud/openapi/api/models_store_api.py +235 -1
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +7 -3
- lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/id_start_body.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/model_id_visibility_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_aws_direct_v1.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/{project_id_agentmanagedmodels_body.py → v1_body.py} +21 -47
- lightning_sdk/lightning_cloud/openapi/models/v1_data_path.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +201 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +53 -53
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model_abilities.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_model.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_state.py +1 -2
- lightning_sdk/lightning_cloud/openapi/models/v1_query_param.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/{v1_list_managed_models_response.py → v1_resource_visibility.py} +23 -23
- lightning_sdk/lightning_cloud/openapi/models/{v1_delete_managed_model_response.py → v1_update_model_visibility_response.py} +6 -6
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +27 -1
- lightning_sdk/models.py +153 -0
- lightning_sdk/teamspace.py +15 -11
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/METADATA +1 -1
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/RECORD +52 -41
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/entry_points.txt +1 -0
- lightning_sdk/cli/models.py +0 -68
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
|
+
|
|
4
|
+
from lightning_sdk.api.utils import (
|
|
5
|
+
_COMPUTE_NAME_TO_MACHINE,
|
|
6
|
+
_MACHINE_TO_COMPUTE_NAME,
|
|
7
|
+
)
|
|
8
|
+
from lightning_sdk.api.utils import (
|
|
9
|
+
_get_cloud_url as _cloud_url,
|
|
10
|
+
)
|
|
11
|
+
from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
|
|
12
|
+
from lightning_sdk.lightning_cloud.openapi import (
|
|
13
|
+
MultimachinejobsIdBody,
|
|
14
|
+
ProjectIdMultimachinejobsBody,
|
|
15
|
+
V1EnvVar,
|
|
16
|
+
V1JobSpec,
|
|
17
|
+
V1MultiMachineJob,
|
|
18
|
+
V1MultiMachineJobState,
|
|
19
|
+
)
|
|
20
|
+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
21
|
+
from lightning_sdk.machine import Machine
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from lightning_sdk.status import Status
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MMTApi:
    """Client-side helper for multi-machine jobs (MMT) on the Lightning cloud.

    Every method delegates to the ``jobs_service_*_multi_machine_job`` calls of
    the underlying ``LightningClient``; this class only assembles request bodies
    and translates backend state strings.
    """

    # Backend state identifiers, compared against ``str(state)`` of a job.
    mmt_state_unspecified = "MultiMachineJob_STATE_UNSPECIFIED"
    mmt_state_running = "MultiMachineJob_STATE_RUNNING"
    mmt_state_stopped = "MultiMachineJob_STATE_STOPPED"
    mmt_state_deleted = "MultiMachineJob_STATE_DELETED"
    mmt_state_failed = "MultiMachineJob_STATE_FAILED"
    mmt_state_completed = "MultiMachineJob_STATE_COMPLETED"

    def __init__(self) -> None:
        """Resolve the cloud URL and create the REST client (``max_tries=7``)."""
        self._cloud_url = _cloud_url()
        self._client = LightningClient(max_tries=7)

    def submit_job(
        self,
        name: str,
        num_machines: int,
        command: Optional[str],
        cluster_id: Optional[str],
        teamspace_id: str,
        studio_id: Optional[str],
        image: Optional[str],
        machine: "Machine",
        interruptible: bool,
        env: Optional[Dict[str, str]],
        image_credentials: Optional[str],
        cluster_auth: bool,
        artifacts_local: Optional[str],
        artifacts_remote: Optional[str],
    ) -> "V1MultiMachineJob":
        """Create a multi-machine job in the given teamspace.

        Optional string arguments that are ``None`` are sent as empty strings.

        Args:
            name: Name of the job.
            num_machines: Number of machines, sent as ``machines`` in the request body.
            command: Command to run.
            cluster_id: Cluster to run on; set on both the job spec and the request body.
            teamspace_id: Teamspace (project) the job is created in.
            studio_id: Studio to use; sent as ``cloudspace_id``.
            image: Docker image to run.
            machine: Machine type; translated through ``_MACHINE_TO_COMPUTE_NAME``.
            interruptible: Sent as the ``spot`` flag of the job spec.
            env: Environment variables, converted to ``V1EnvVar`` entries.
            image_credentials: Sent as ``image_secret_ref``.
            cluster_auth: Sent as ``image_cluster_credentials``.
            artifacts_local: Sent as ``artifacts_source``.
            artifacts_remote: Sent as ``artifacts_destination``.

        Returns:
            The job object returned by the create call.

        Raises:
            KeyError: If ``machine`` has no entry in ``_MACHINE_TO_COMPUTE_NAME`` or
                ``studio_id`` has no entry in the global unique-ids store.
        """
        env_vars = [V1EnvVar(name=key, value=value) for key, value in (env or {}).items()]

        instance_name = _MACHINE_TO_COMPUTE_NAME[machine]

        # The run id is looked up per studio; jobs without a studio send an empty one.
        run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""

        spec = V1JobSpec(
            cloudspace_id=studio_id or "",
            cluster_id=cluster_id or "",
            command=command or "",
            env=env_vars,
            image=image or "",
            instance_name=instance_name,
            run_id=run_id,
            spot=interruptible,
            image_cluster_credentials=cluster_auth,
            image_secret_ref=image_credentials or "",
            artifacts_source=artifacts_local or "",
            artifacts_destination=artifacts_remote or "",
        )
        body = ProjectIdMultimachinejobsBody(name=name, spec=spec, cluster_id=cluster_id or "", machines=num_machines)

        job: V1MultiMachineJob = self._client.jobs_service_create_multi_machine_job(project_id=teamspace_id, body=body)
        return job

    def get_job_by_name(self, name: str, teamspace_id: str) -> "V1MultiMachineJob":
        """Fetch a multi-machine job of the teamspace by its name."""
        job: V1MultiMachineJob = self._client.jobs_service_get_multi_machine_job_by_name(
            project_id=teamspace_id, name=name
        )
        return job

    def get_job(self, job_id: str, teamspace_id: str) -> "V1MultiMachineJob":
        """Fetch a multi-machine job of the teamspace by its id."""
        job: V1MultiMachineJob = self._client.jobs_service_get_multi_machine_job(project_id=teamspace_id, id=job_id)
        return job

    def _has_settled(self, job: "V1MultiMachineJob") -> bool:
        """Return True if the job's desired state is stopped, completed or failed.

        These are exactly the backend states that ``_job_state_to_external`` maps
        to ``Status.Stopped``, ``Status.Completed`` and ``Status.Failed``.
        """
        return str(job.desired_state) in (
            self.mmt_state_stopped,
            self.mmt_state_completed,
            self.mmt_state_failed,
        )

    def stop_job(self, job_id: str, teamspace_id: str) -> None:
        """Request the job to stop and block until it reports a settled state.

        Does nothing if the job's desired state is already stopped, completed or
        failed. Otherwise the desired state is set to stopped and the job is
        polled once per second until its desired state is one of those three.
        """
        if self._has_settled(self.get_job(job_id=job_id, teamspace_id=teamspace_id)):
            return

        update_body = MultimachinejobsIdBody(desired_state=self.mmt_state_stopped)
        self._client.jobs_service_update_multi_machine_job(body=update_body, project_id=teamspace_id, id=job_id)

        # NOTE(review): the wait has no upper bound, and a job whose desired state
        # becomes "deleted" never satisfies the check - confirm this is intended.
        while not self._has_settled(self.get_job(job_id=job_id, teamspace_id=teamspace_id)):
            time.sleep(1)

    def delete_job(self, job_id: str, teamspace_id: str) -> None:
        """Delete the multi-machine job with the given id from the teamspace."""
        self._client.jobs_service_delete_multi_machine_job(project_id=teamspace_id, id=job_id)

    def _job_state_to_external(self, state: "V1MultiMachineJobState") -> "Status":
        """Translate a backend job state into the SDK-level ``Status``.

        Unspecified and any unrecognised state (including the deleted state)
        map to ``Status.Pending``.
        """
        # Imported lazily; at module level ``Status`` is only imported for type checking.
        from lightning_sdk.status import Status

        external_by_state = {
            self.mmt_state_unspecified: Status.Pending,
            self.mmt_state_running: Status.Running,
            self.mmt_state_stopped: Status.Stopped,
            self.mmt_state_completed: Status.Completed,
            self.mmt_state_failed: Status.Failed,
        }
        return external_by_state.get(str(state), Status.Pending)

    def _get_job_machine_from_spec(self, spec: "V1JobSpec") -> "Machine":
        """Resolve the machine of a job spec via ``_COMPUTE_NAME_TO_MACHINE``.

        ``spec.instance_type`` is looked up first, then ``spec.instance_name``.
        If neither is a known key, the raw ``instance_type`` (or, when that is
        falsy, ``instance_name``) is returned unchanged.
        """
        instance_name = spec.instance_name
        instance_type = spec.instance_type

        return _COMPUTE_NAME_TO_MACHINE.get(
            instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
        )
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from dataclasses import dataclass
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
from typing import Dict, List, Optional
|
|
5
4
|
|
|
@@ -260,22 +259,16 @@ class TeamspaceApi:
|
|
|
260
259
|
name: str,
|
|
261
260
|
version: str,
|
|
262
261
|
download_dir: Path,
|
|
263
|
-
|
|
262
|
+
teamspace_name: str,
|
|
263
|
+
teamspace_owner_name: str,
|
|
264
264
|
progress_bar: bool = True,
|
|
265
265
|
) -> List[str]:
|
|
266
266
|
return _download_model_files(
|
|
267
267
|
client=self._client,
|
|
268
|
-
|
|
268
|
+
teamspace_name=teamspace_name,
|
|
269
|
+
teamspace_owner_name=teamspace_owner_name,
|
|
269
270
|
name=name,
|
|
270
271
|
version=version,
|
|
271
272
|
download_dir=download_dir,
|
|
272
273
|
progress_bar=progress_bar,
|
|
273
274
|
)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
@dataclass
|
|
277
|
-
class UploadedModelInfo:
|
|
278
|
-
name: str
|
|
279
|
-
version: str
|
|
280
|
-
teamspace: str
|
|
281
|
-
cluster: str
|
lightning_sdk/api/utils.py
CHANGED
|
@@ -513,7 +513,8 @@ def _get_model_version(client: LightningClient, teamspace_id: str, name: str, ve
|
|
|
513
513
|
|
|
514
514
|
def _download_model_files(
|
|
515
515
|
client: LightningClient,
|
|
516
|
-
|
|
516
|
+
teamspace_name: str,
|
|
517
|
+
teamspace_owner_name: str,
|
|
517
518
|
name: str,
|
|
518
519
|
version: str,
|
|
519
520
|
download_dir: Path,
|
|
@@ -521,7 +522,9 @@ def _download_model_files(
|
|
|
521
522
|
num_workers: int = 20,
|
|
522
523
|
) -> List[str]:
|
|
523
524
|
api = ModelsStoreApi(client.api_client)
|
|
524
|
-
response = api.models_store_get_model_files(
|
|
525
|
+
response = api.models_store_get_model_files(
|
|
526
|
+
project_name=teamspace_name, project_owner_name=teamspace_owner_name, name=name, version=version
|
|
527
|
+
)
|
|
525
528
|
|
|
526
529
|
pbar = None
|
|
527
530
|
if progress_bar:
|
|
@@ -541,7 +544,7 @@ def _download_model_files(
|
|
|
541
544
|
client=client,
|
|
542
545
|
model_id=response.model_id,
|
|
543
546
|
version=response.version,
|
|
544
|
-
teamspace_id=
|
|
547
|
+
teamspace_id=response.project_id,
|
|
545
548
|
remote_path=filepath,
|
|
546
549
|
file_path=str(local_file),
|
|
547
550
|
num_workers=num_workers,
|
lightning_sdk/cli/download.py
CHANGED
|
@@ -4,8 +4,8 @@ from pathlib import Path
|
|
|
4
4
|
from typing import Optional
|
|
5
5
|
|
|
6
6
|
from lightning_sdk.cli.exceptions import StudioCliError
|
|
7
|
-
from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
|
|
8
7
|
from lightning_sdk.cli.studios_menu import _StudiosMenu
|
|
8
|
+
from lightning_sdk.models import download_model
|
|
9
9
|
from lightning_sdk.studio import Studio
|
|
10
10
|
from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
|
|
11
11
|
|
|
@@ -21,10 +21,8 @@ class _Downloads(_StudiosMenu):
|
|
|
21
21
|
This should have the format <ORGANIZATION-NAME>/<TEAMSPACE-NAME>/<MODEL-NAME>.
|
|
22
22
|
download_dir: The directory where the Model should be downloaded.
|
|
23
23
|
"""
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
teamspace.download_model(
|
|
27
|
-
name=model_name,
|
|
24
|
+
download_model(
|
|
25
|
+
name=name,
|
|
28
26
|
download_dir=download_dir,
|
|
29
27
|
progress_bar=True,
|
|
30
28
|
)
|
lightning_sdk/cli/mmt.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from typing import Dict, Optional
|
|
2
|
+
|
|
3
|
+
from fire import Fire
|
|
4
|
+
|
|
5
|
+
from lightning_sdk._mmt import MMT
|
|
6
|
+
from lightning_sdk.api.studio_api import _cloud_url
|
|
7
|
+
from lightning_sdk.lightning_cloud.login import Auth
|
|
8
|
+
from lightning_sdk.machine import Machine
|
|
9
|
+
from lightning_sdk.teamspace import Teamspace
|
|
10
|
+
|
|
11
|
+
# All machine names accepted by the CLI, in enum definition order (used in the help text).
_MACHINE_VALUES = tuple(machine.value for machine in Machine)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MMTCLI:
    """Command line interface (CLI) to interact with/manage Lightning AI MMT."""

    def __init__(self) -> None:
        """Attach the help text of ``run``, which lists the available machine types."""
        # Need to set the docstring here for f-strings to work.
        # Sadly this is the only way to really show options as f-strings are not allowed as docstrings directly
        # and fire does not show values for literals, just that it is a literal.
        docstr = f"""Run async workloads on multiple machines using a docker image.

        Args:
            name: The name of the job. Needs to be unique within the teamspace.
            num_machines: The number of Machines to run on. Defaults to 2 Machines
            machine: The machine type to run the job on. One of {", ".join(_MACHINE_VALUES)}. Defaults to CPU
            command: The command to run inside your job. Required if using a studio. Optional if using an image.
                If not provided for images, will run the container entrypoint and default command.
            studio: The studio env to run the job with. Mutually exclusive with image.
            image: The docker image to run the job with. Mutually exclusive with studio.
            teamspace: The teamspace the job should be associated with. Defaults to the current teamspace.
            org: The organization owning the teamspace (if any). Defaults to the current organization.
            user: The user owning the teamspace (if any). Defaults to the current user.
            cluster: The cluster to run the job on. Defaults to the studio cluster if running with studio compute env.
                If not provided will fall back to the teamspaces default cluster.
            env: Environment variables to set inside the job.
            interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
            image_credentials: The credentials used to pull the image. Required if the image is private.
                This should be the name of the respective credentials secret created on the Lightning AI platform.
            cluster_auth: Whether to authenticate with the cluster to pull the image.
                Required if the registry is part of a cluster provider (e.g. ECR).
            artifacts_local: The path inside the docker container you want to persist artifacts from.
                CAUTION: When setting this to "/", it will effectively erase your container.
                Only supported for jobs with a docker image compute environment.
            artifacts_remote: The remote storage to persist your artifacts to.
                Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
                PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
                E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
                within it.
                Note that the connection needs to be added to the teamspace already in order for it to be found.
                Only supported for jobs with a docker image compute environment.
        """
        # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
        # might need to switch to explicit cli definition
        self.run.__func__.__doc__ = docstr

    def login(self) -> None:
        """Login to Lightning AI Studios."""
        auth = Auth()
        # Drop any stored credentials so that authenticate() starts from a clean state.
        auth.clear()

        try:
            auth.authenticate()
        except ConnectionError:
            raise RuntimeError(f"Unable to connect to {_cloud_url()}. Please check your internet connection.") from None

    def logout(self) -> None:
        """Logout from Lightning AI Studios."""
        auth = Auth()
        auth.clear()

    # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
    # see https://github.com/google/python-fire/pull/513
    # might need to move to different cli library
    def run(
        self,
        name: Optional[str] = None,
        num_machines: int = 2,
        machine: Optional[str] = None,
        command: Optional[str] = None,
        studio: Optional[str] = None,
        image: Optional[str] = None,
        teamspace: Optional[str] = None,
        org: Optional[str] = None,
        user: Optional[str] = None,
        cluster: Optional[str] = None,
        env: Optional[Dict[str, str]] = None,
        interruptible: bool = False,
        image_credentials: Optional[str] = None,
        cluster_auth: bool = False,
        artifacts_local: Optional[str] = None,
        artifacts_remote: Optional[str] = None,
    ) -> None:
        # NOTE: the user-facing docstring of this method is assigned in __init__ (it needs an f-string).
        # Fail fast: a missing image is a pure argument error, so report it before
        # resolving the teamspace or touching anything else.
        if image is None:
            raise RuntimeError("Currently only docker images are supported")

        if name is None:
            from datetime import datetime

            # Default to a timestamped name such as "mmt-2024-01-31-12-00-00".
            timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            name = f"mmt-{timestr}"

        if machine is None:
            # TODO: infer from studio
            machine = "CPU"
        machine_enum = Machine(machine.upper())

        # Kept in its own local so the ``teamspace`` argument (a name) is not rebound to another type.
        resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
        if cluster is None:
            cluster = resolved_teamspace.default_cluster

        MMT.run(
            name=name,
            num_machines=num_machines,
            machine=machine_enum,
            command=command,
            studio=studio,
            image=image,
            teamspace=resolved_teamspace,
            org=org,
            user=user,
            cluster=cluster,
            env=env,
            interruptible=interruptible,
            image_credentials=image_credentials,
            cluster_auth=cluster_auth,
            artifacts_local=artifacts_local,
            artifacts_remote=artifacts_remote,
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def main_cli() -> None:
    """Expose ``MMTCLI`` through ``fire`` under the command name ``_mmt``."""
    cli = MMTCLI()
    Fire(cli, name="_mmt")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
if __name__ == "__main__":
|
|
137
|
+
main_cli()
|
lightning_sdk/cli/run.py
CHANGED
|
@@ -40,7 +40,19 @@ class _Run:
|
|
|
40
40
|
This should be the name of the respective credentials secret created on the Lightning AI platform.
|
|
41
41
|
cluster_auth: Whether to authenticate with the cluster to pull the image.
|
|
42
42
|
Required if the registry is part of a cluster provider (e.g. ECR).
|
|
43
|
+
artifacts_local: The path of inside the docker container, you want to persist images from.
|
|
44
|
+
CAUTION: When setting this to "/", it will effectively erase your container.
|
|
45
|
+
Only supported for jobs with a docker image compute environment.
|
|
46
|
+
artifacts_remote: The remote storage to persist your artifacts to.
|
|
47
|
+
Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
|
|
48
|
+
PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
|
|
49
|
+
E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
|
|
50
|
+
within it.
|
|
51
|
+
Note that the connection needs to be added to the teamspace already in order for it to be found.
|
|
52
|
+
Only supported for jobs with a docker image compute environment.
|
|
43
53
|
"""
|
|
54
|
+
# TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
|
|
55
|
+
# might need to switch to explicit cli definition
|
|
44
56
|
self.job.__func__.__doc__ = docstr
|
|
45
57
|
|
|
46
58
|
# TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
|
|
@@ -61,6 +73,8 @@ class _Run:
|
|
|
61
73
|
interruptible: bool = False,
|
|
62
74
|
image_credentials: Optional[str] = None,
|
|
63
75
|
cluster_auth: bool = False,
|
|
76
|
+
artifacts_local: Optional[str] = None,
|
|
77
|
+
artifacts_remote: Optional[str] = None,
|
|
64
78
|
) -> None:
|
|
65
79
|
machine_enum = Machine(machine.upper())
|
|
66
80
|
Job.run(
|
|
@@ -77,4 +91,6 @@ class _Run:
|
|
|
77
91
|
interruptible=interruptible,
|
|
78
92
|
image_credentials=image_credentials,
|
|
79
93
|
cluster_auth=cluster_auth,
|
|
94
|
+
artifacts_local=artifacts_local,
|
|
95
|
+
artifacts_remote=artifacts_remote,
|
|
80
96
|
)
|
lightning_sdk/cli/upload.py
CHANGED
|
@@ -9,8 +9,8 @@ from tqdm import tqdm
|
|
|
9
9
|
|
|
10
10
|
from lightning_sdk.api.utils import _get_cloud_url
|
|
11
11
|
from lightning_sdk.cli.exceptions import StudioCliError
|
|
12
|
-
from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
|
|
13
12
|
from lightning_sdk.cli.studios_menu import _StudiosMenu
|
|
13
|
+
from lightning_sdk.models import upload_model
|
|
14
14
|
from lightning_sdk.studio import Studio
|
|
15
15
|
from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
|
|
16
16
|
|
|
@@ -20,7 +20,7 @@ class _Uploads(_StudiosMenu):
|
|
|
20
20
|
|
|
21
21
|
_studio_upload_status_path = "~/.lightning/studios/uploads"
|
|
22
22
|
|
|
23
|
-
def model(self, name: str, path:
|
|
23
|
+
def model(self, name: str, path: str = ".", cloud_account: Optional[str] = None) -> None:
|
|
24
24
|
"""Upload a Model.
|
|
25
25
|
|
|
26
26
|
Args:
|
|
@@ -29,14 +29,7 @@ class _Uploads(_StudiosMenu):
|
|
|
29
29
|
path: The path to the file or directory you want to upload. Defaults to the current directory.
|
|
30
30
|
cloud_account: The name of the cloud account to store the Model in.
|
|
31
31
|
"""
|
|
32
|
-
|
|
33
|
-
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
|
|
34
|
-
teamspace.upload_model(
|
|
35
|
-
path=path or ".",
|
|
36
|
-
name=model_name,
|
|
37
|
-
progress_bar=True,
|
|
38
|
-
cluster_id=cloud_account,
|
|
39
|
-
)
|
|
32
|
+
upload_model(name, path, cloud_account=cloud_account)
|
|
40
33
|
|
|
41
34
|
def _resolve_studio(self, studio: Optional[str]) -> Studio:
|
|
42
35
|
user = _get_authed_user()
|
lightning_sdk/job/base.py
CHANGED
|
@@ -52,6 +52,8 @@ class _BaseJob(ABC):
|
|
|
52
52
|
interruptible: bool = False,
|
|
53
53
|
image_credentials: Optional[str] = None,
|
|
54
54
|
cluster_auth: bool = False,
|
|
55
|
+
artifacts_local: Optional[str] = None,
|
|
56
|
+
artifacts_remote: Optional[str] = None,
|
|
55
57
|
) -> "_BaseJob":
|
|
56
58
|
from lightning_sdk.studio import Studio
|
|
57
59
|
|
|
@@ -89,14 +91,30 @@ class _BaseJob(ABC):
|
|
|
89
91
|
if cluster_auth:
|
|
90
92
|
raise ValueError("cluster_auth is only supported when using a custom image")
|
|
91
93
|
|
|
94
|
+
if artifacts_local is not None or artifacts_remote is not None:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
"Specifying artifacts persistence is supported for docker images only. "
|
|
97
|
+
"Other jobs will automatically persist artifacts to the teamspace distributed filesystem."
|
|
98
|
+
)
|
|
99
|
+
|
|
92
100
|
else:
|
|
93
101
|
if studio is not None:
|
|
94
102
|
raise RuntimeError(
|
|
95
103
|
"image and studio are mutually exclusive as both define the environment to run the job in"
|
|
96
104
|
)
|
|
97
105
|
|
|
106
|
+
# they either need to specified both or none of them
|
|
107
|
+
if bool(artifacts_local) != bool(artifacts_remote):
|
|
108
|
+
raise ValueError("Artifact persistence requires both artifacts_local and artifacts_remote to be set")
|
|
109
|
+
|
|
110
|
+
if artifacts_remote and len(artifacts_remote.split(":")) != 3:
|
|
111
|
+
raise ValueError(
|
|
112
|
+
"Artifact persistence requires exactly three arguments separated by colon of kind "
|
|
113
|
+
f"<CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>, got {artifacts_local}"
|
|
114
|
+
)
|
|
115
|
+
|
|
98
116
|
inst = cls(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=False)
|
|
99
|
-
inst._submit(
|
|
117
|
+
return inst._submit(
|
|
100
118
|
machine=machine,
|
|
101
119
|
cluster=cluster,
|
|
102
120
|
command=command,
|
|
@@ -106,8 +124,9 @@ class _BaseJob(ABC):
|
|
|
106
124
|
interruptible=interruptible,
|
|
107
125
|
image_credentials=image_credentials,
|
|
108
126
|
cluster_auth=cluster_auth,
|
|
127
|
+
artifacts_local=artifacts_local,
|
|
128
|
+
artifacts_remote=artifacts_remote,
|
|
109
129
|
)
|
|
110
|
-
return inst
|
|
111
130
|
|
|
112
131
|
@abstractmethod
|
|
113
132
|
def _submit(
|
|
@@ -121,7 +140,9 @@ class _BaseJob(ABC):
|
|
|
121
140
|
cluster: Optional[str] = None,
|
|
122
141
|
image_credentials: Optional[str] = None,
|
|
123
142
|
cluster_auth: bool = False,
|
|
124
|
-
|
|
143
|
+
artifacts_local: Optional[str] = None,
|
|
144
|
+
artifacts_remote: Optional[str] = None,
|
|
145
|
+
) -> "_BaseJob":
|
|
125
146
|
"""Submits a job and updates the internal _job attribute as well as the _name attribute."""
|
|
126
147
|
|
|
127
148
|
@abstractmethod
|
lightning_sdk/job/job.py
CHANGED
|
@@ -60,6 +60,8 @@ class Job(_BaseJob):
|
|
|
60
60
|
interruptible: bool = False,
|
|
61
61
|
image_credentials: Optional[str] = None,
|
|
62
62
|
cluster_auth: bool = False,
|
|
63
|
+
artifacts_local: Optional[str] = None,
|
|
64
|
+
artifacts_remote: Optional[str] = None,
|
|
63
65
|
) -> "Job":
|
|
64
66
|
ret_val = super().run(
|
|
65
67
|
name=name,
|
|
@@ -75,6 +77,8 @@ class Job(_BaseJob):
|
|
|
75
77
|
interruptible=interruptible,
|
|
76
78
|
image_credentials=image_credentials,
|
|
77
79
|
cluster_auth=cluster_auth,
|
|
80
|
+
artifacts_local=artifacts_local,
|
|
81
|
+
artifacts_remote=artifacts_remote,
|
|
78
82
|
)
|
|
79
83
|
# required for typing with "Job"
|
|
80
84
|
assert isinstance(ret_val, cls)
|
|
@@ -91,8 +95,10 @@ class Job(_BaseJob):
|
|
|
91
95
|
cluster: Optional[str] = None,
|
|
92
96
|
image_credentials: Optional[str] = None,
|
|
93
97
|
cluster_auth: bool = False,
|
|
98
|
+
artifacts_local: Optional[str] = None,
|
|
99
|
+
artifacts_remote: Optional[str] = None,
|
|
94
100
|
) -> None:
|
|
95
|
-
|
|
101
|
+
self._job = self._internal_job._submit(
|
|
96
102
|
machine=machine,
|
|
97
103
|
cluster=cluster,
|
|
98
104
|
command=command,
|
|
@@ -102,7 +108,10 @@ class Job(_BaseJob):
|
|
|
102
108
|
interruptible=interruptible,
|
|
103
109
|
image_credentials=image_credentials,
|
|
104
110
|
cluster_auth=cluster_auth,
|
|
111
|
+
artifacts_local=artifacts_local,
|
|
112
|
+
artifacts_remote=artifacts_remote,
|
|
105
113
|
)
|
|
114
|
+
return self
|
|
106
115
|
|
|
107
116
|
def stop(self) -> None:
|
|
108
117
|
return self._internal_job.stop()
|
lightning_sdk/job/v1.py
CHANGED
|
@@ -69,13 +69,18 @@ class _JobV1(_BaseJob):
|
|
|
69
69
|
cluster: Optional[str] = None,
|
|
70
70
|
image_credentials: Optional[str] = None,
|
|
71
71
|
cluster_auth: bool = False,
|
|
72
|
-
|
|
72
|
+
artifacts_local: Optional[str] = None,
|
|
73
|
+
artifacts_remote: Optional[str] = None,
|
|
74
|
+
) -> "_JobV1":
|
|
73
75
|
if studio is None:
|
|
74
76
|
raise ValueError("Studio is required for submitting jobs")
|
|
75
77
|
|
|
76
78
|
if image is not None or image_credentials is not None or cluster_auth:
|
|
77
79
|
raise ValueError("Image is not supported for submitting jobs")
|
|
78
80
|
|
|
81
|
+
if artifacts_local is not None or artifacts_remote is not None:
|
|
82
|
+
raise ValueError("Specifying how to persist artifacts is not yet supported with jobs")
|
|
83
|
+
|
|
79
84
|
if env is not None:
|
|
80
85
|
raise ValueError("Environment variables are not supported for submitting jobs")
|
|
81
86
|
|
|
@@ -95,6 +100,7 @@ class _JobV1(_BaseJob):
|
|
|
95
100
|
)
|
|
96
101
|
self._name = _submitted.name
|
|
97
102
|
self._job = _submitted
|
|
103
|
+
return self
|
|
98
104
|
|
|
99
105
|
def _update_internal_job(self) -> None:
|
|
100
106
|
try:
|
lightning_sdk/job/v2.py
CHANGED
|
@@ -36,7 +36,9 @@ class _JobV2(_BaseJob):
|
|
|
36
36
|
cluster: Optional[str] = None,
|
|
37
37
|
image_credentials: Optional[str] = None,
|
|
38
38
|
cluster_auth: bool = False,
|
|
39
|
-
|
|
39
|
+
artifacts_local: Optional[str] = None,
|
|
40
|
+
artifacts_remote: Optional[str] = None,
|
|
41
|
+
) -> "_JobV2":
|
|
40
42
|
# Command is required if Studio is provided to know what to run
|
|
41
43
|
# Image is mutually exclusive with Studio
|
|
42
44
|
# Command is optional for Image
|
|
@@ -66,22 +68,21 @@ class _JobV2(_BaseJob):
|
|
|
66
68
|
env=env,
|
|
67
69
|
image_credentials=image_credentials,
|
|
68
70
|
cluster_auth=cluster_auth,
|
|
71
|
+
artifacts_local=artifacts_local,
|
|
72
|
+
artifacts_remote=artifacts_remote,
|
|
69
73
|
)
|
|
70
74
|
self._job = submitted
|
|
71
75
|
self._name = submitted.name
|
|
76
|
+
return self
|
|
72
77
|
|
|
73
78
|
def stop(self) -> None:
|
|
74
|
-
|
|
75
|
-
self._update_internal_job()
|
|
76
|
-
|
|
77
|
-
self._job_api.stop_job(job_id=self._job.id, teamspace_id=self._teamspace.id)
|
|
79
|
+
self._job_api.stop_job(job_id=self._guaranteed_job.id, teamspace_id=self._teamspace.id)
|
|
78
80
|
|
|
79
81
|
def delete(self) -> None:
|
|
80
|
-
if self._job is None:
|
|
81
|
-
self._update_internal_job()
|
|
82
|
-
|
|
83
82
|
self._job_api.delete_job(
|
|
84
|
-
job_id=self.
|
|
83
|
+
job_id=self._guaranteed_job.id,
|
|
84
|
+
teamspace_id=self._teamspace.id,
|
|
85
|
+
cloudspace_id=self._guaranteed_job.spec.cloudspace_id,
|
|
85
86
|
)
|
|
86
87
|
|
|
87
88
|
@property
|
|
@@ -112,10 +113,18 @@ class _JobV2(_BaseJob):
|
|
|
112
113
|
|
|
113
114
|
@property
|
|
114
115
|
def artifact_path(self) -> Optional[str]:
    """Return the teamspace path under which the job's artifacts are found.

    For jobs with an image, the path is derived from the spec's
    ``artifacts_destination`` (``<CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH>``);
    ``None`` is returned when no destination is set. For all other jobs the
    path is ``/teamspace/jobs/<job name>/artifacts``.

    Raises:
        IndexError: If the destination has fewer than three colon-separated parts.
    """
    job = self._guaranteed_job
    spec = job.spec

    # Truthiness instead of ``!= ""`` so that an unset field reported as None
    # is treated like an empty string (NOTE(review): confirm whether the API
    # can actually return None here).
    if spec.image:
        if not spec.artifacts_destination:
            return None
        # maxsplit keeps any further colons as part of the path component.
        splits = spec.artifacts_destination.split(":", 2)
        return f"/teamspace/{splits[0]}_connections/{splits[1]}/{splits[2]}"

    return f"/teamspace/jobs/{job.name}/artifacts"
|
|
116
123
|
|
|
117
124
|
@property
|
|
118
125
|
def snapshot_path(self) -> Optional[str]:
    """Return the teamspace path of the job's snapshot.

    Jobs with an image have no snapshot path and yield ``None``; all other
    jobs yield ``/teamspace/jobs/<job name>/snapshot``.
    """
    job = self._guaranteed_job
    # Truthiness instead of ``!= ""`` so that an image reported as None is
    # treated as "no image" (NOTE(review): confirm whether the API can return None).
    if job.spec.image:
        return None
    return f"/teamspace/jobs/{job.name}/snapshot"
|
|
120
129
|
|
|
121
130
|
@property
|