lightning-sdk 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/_mmt/__init__.py +3 -0
  3. lightning_sdk/_mmt/base.py +180 -0
  4. lightning_sdk/_mmt/mmt.py +161 -0
  5. lightning_sdk/_mmt/v1.py +69 -0
  6. lightning_sdk/_mmt/v2.py +141 -0
  7. lightning_sdk/api/deployment_api.py +0 -2
  8. lightning_sdk/api/job_api.py +4 -0
  9. lightning_sdk/api/mmt_api.py +147 -0
  10. lightning_sdk/api/teamspace_api.py +4 -11
  11. lightning_sdk/api/utils.py +6 -3
  12. lightning_sdk/cli/download.py +3 -5
  13. lightning_sdk/cli/mmt.py +137 -0
  14. lightning_sdk/cli/run.py +16 -0
  15. lightning_sdk/cli/upload.py +3 -10
  16. lightning_sdk/job/base.py +24 -3
  17. lightning_sdk/job/job.py +10 -1
  18. lightning_sdk/job/v1.py +7 -1
  19. lightning_sdk/job/v2.py +18 -9
  20. lightning_sdk/lightning_cloud/openapi/__init__.py +7 -3
  21. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +90 -284
  22. lightning_sdk/lightning_cloud/openapi/api/data_connection_service_api.py +6 -1
  23. lightning_sdk/lightning_cloud/openapi/api/models_store_api.py +235 -1
  24. lightning_sdk/lightning_cloud/openapi/models/__init__.py +7 -3
  25. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  26. lightning_sdk/lightning_cloud/openapi/models/id_start_body.py +29 -3
  27. lightning_sdk/lightning_cloud/openapi/models/model_id_visibility_body.py +123 -0
  28. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  29. lightning_sdk/lightning_cloud/openapi/models/v1_aws_direct_v1.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/{project_id_agentmanagedmodels_body.py → v1_body.py} +21 -47
  31. lightning_sdk/lightning_cloud/openapi/models/v1_data_path.py +29 -3
  32. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +201 -0
  34. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +27 -1
  35. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +53 -53
  36. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
  37. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model_abilities.py +175 -0
  38. lightning_sdk/lightning_cloud/openapi/models/v1_model.py +29 -3
  39. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +53 -1
  40. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_state.py +1 -2
  41. lightning_sdk/lightning_cloud/openapi/models/v1_query_param.py +175 -0
  42. lightning_sdk/lightning_cloud/openapi/models/{v1_list_managed_models_response.py → v1_resource_visibility.py} +23 -23
  43. lightning_sdk/lightning_cloud/openapi/models/{v1_delete_managed_model_response.py → v1_update_model_visibility_response.py} +6 -6
  44. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +27 -1
  45. lightning_sdk/models.py +153 -0
  46. lightning_sdk/teamspace.py +15 -11
  47. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/METADATA +1 -1
  48. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/RECORD +52 -41
  49. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/entry_points.txt +1 -0
  50. lightning_sdk/cli/models.py +0 -68
  51. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/LICENSE +0 -0
  52. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/WHEEL +0 -0
  53. {lightning_sdk-0.1.38.dist-info → lightning_sdk-0.1.40.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
1
+ import time
2
+ from typing import TYPE_CHECKING, Dict, Optional
3
+
4
+ from lightning_sdk.api.utils import (
5
+ _COMPUTE_NAME_TO_MACHINE,
6
+ _MACHINE_TO_COMPUTE_NAME,
7
+ )
8
+ from lightning_sdk.api.utils import (
9
+ _get_cloud_url as _cloud_url,
10
+ )
11
+ from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
12
+ from lightning_sdk.lightning_cloud.openapi import (
13
+ MultimachinejobsIdBody,
14
+ ProjectIdMultimachinejobsBody,
15
+ V1EnvVar,
16
+ V1JobSpec,
17
+ V1MultiMachineJob,
18
+ V1MultiMachineJobState,
19
+ )
20
+ from lightning_sdk.lightning_cloud.rest_client import LightningClient
21
+ from lightning_sdk.machine import Machine
22
+
23
+ if TYPE_CHECKING:
24
+ from lightning_sdk.status import Status
25
+
26
+
27
class MMTApi:
    """Wrapper around ``LightningClient`` for multi-machine jobs: create, fetch, stop and delete."""

    # Raw backend state names; they are compared against ``str(state)`` in ``_job_state_to_external``.
    mmt_state_unspecified = "MultiMachineJob_STATE_UNSPECIFIED"
    mmt_state_running = "MultiMachineJob_STATE_RUNNING"
    mmt_state_stopped = "MultiMachineJob_STATE_STOPPED"
    mmt_state_deleted = "MultiMachineJob_STATE_DELETED"
    mmt_state_failed = "MultiMachineJob_STATE_FAILED"
    mmt_state_completed = "MultiMachineJob_STATE_COMPLETED"

    def __init__(self) -> None:
        """Resolve the cloud URL and create the REST client (``max_tries=7``)."""
        self._cloud_url = _cloud_url()
        self._client = LightningClient(max_tries=7)

    def submit_job(
        self,
        name: str,
        num_machines: int,
        command: Optional[str],
        cluster_id: Optional[str],
        teamspace_id: str,
        studio_id: Optional[str],
        image: Optional[str],
        machine: Machine,
        interruptible: bool,
        env: Optional[Dict[str, str]],
        image_credentials: Optional[str],
        cluster_auth: bool,
        artifacts_local: Optional[str],
        artifacts_remote: Optional[str],
    ) -> V1MultiMachineJob:
        """Build a job spec from the arguments and create a multi-machine job in the teamspace.

        Optional string arguments that are ``None`` are sent to the API as empty strings.

        Args:
            name: Name of the job.
            num_machines: Number of machines, forwarded as ``machines`` on the request body.
            command: Command for the job spec.
            cluster_id: Cluster id, set on both the spec and the request body.
            teamspace_id: Id of the teamspace, used as ``project_id``.
            studio_id: Studio id, used as ``cloudspace_id`` and to look up the ``run_id``.
            image: Image for the job spec.
            machine: Machine type; translated through ``_MACHINE_TO_COMPUTE_NAME``.
            interruptible: Forwarded as ``spot``.
            env: Environment variables, converted to ``V1EnvVar`` entries.
            image_credentials: Forwarded as ``image_secret_ref``.
            cluster_auth: Forwarded as ``image_cluster_credentials``.
            artifacts_local: Forwarded as ``artifacts_source``.
            artifacts_remote: Forwarded as ``artifacts_destination``.

        Returns:
            The job object returned by the create call.

        Raises:
            KeyError: If ``machine`` has no entry in ``_MACHINE_TO_COMPUTE_NAME`` or ``studio_id`` is not
                present in ``__GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__``.
        """
        env_vars = [V1EnvVar(name=k, value=v) for k, v in (env or {}).items()]

        instance_name = _MACHINE_TO_COMPUTE_NAME[machine]

        run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""

        spec = V1JobSpec(
            cloudspace_id=studio_id or "",
            cluster_id=cluster_id or "",
            command=command or "",
            env=env_vars,
            image=image or "",
            instance_name=instance_name,
            run_id=run_id,
            spot=interruptible,
            image_cluster_credentials=cluster_auth,
            image_secret_ref=image_credentials or "",
            artifacts_source=artifacts_local or "",
            artifacts_destination=artifacts_remote or "",
        )
        body = ProjectIdMultimachinejobsBody(name=name, spec=spec, cluster_id=cluster_id or "", machines=num_machines)

        job: V1MultiMachineJob = self._client.jobs_service_create_multi_machine_job(project_id=teamspace_id, body=body)
        return job

    def get_job_by_name(self, name: str, teamspace_id: str) -> V1MultiMachineJob:
        """Fetch a multi-machine job of the given teamspace by its name."""
        job: V1MultiMachineJob = self._client.jobs_service_get_multi_machine_job_by_name(
            project_id=teamspace_id, name=name
        )
        return job

    def get_job(self, job_id: str, teamspace_id: str) -> V1MultiMachineJob:
        """Fetch a multi-machine job of the given teamspace by its id."""
        job: V1MultiMachineJob = self._client.jobs_service_get_multi_machine_job(project_id=teamspace_id, id=job_id)
        return job

    def stop_job(self, job_id: str, teamspace_id: str) -> None:
        """Request the job to stop and block until it reports a terminal state.

        Returns immediately when the job already maps to Stopped, Completed or Failed. Otherwise the
        desired state is set to stopped and the job is polled once per second; there is no timeout.
        """
        from lightning_sdk.status import Status

        # Single definition of "finished" shared by the early return and the polling loop.
        terminal_states = (Status.Stopped, Status.Completed, Status.Failed)

        current_job = self.get_job(job_id=job_id, teamspace_id=teamspace_id)
        if self._job_state_to_external(current_job.desired_state) in terminal_states:
            return

        update_body = MultimachinejobsIdBody(desired_state=self.mmt_state_stopped)
        self._client.jobs_service_update_multi_machine_job(body=update_body, project_id=teamspace_id, id=job_id)

        # NOTE(review): this polls ``desired_state`` (the field that was just updated), not an observed
        # state -- confirm that this is the intended signal for "the job has actually stopped".
        while True:
            current_job = self.get_job(job_id=job_id, teamspace_id=teamspace_id)
            if self._job_state_to_external(current_job.desired_state) in terminal_states:
                break
            time.sleep(1)

    def delete_job(self, job_id: str, teamspace_id: str) -> None:
        """Delete the multi-machine job with the given id from the teamspace."""
        self._client.jobs_service_delete_multi_machine_job(project_id=teamspace_id, id=job_id)

    def _job_state_to_external(self, state: V1MultiMachineJobState) -> "Status":
        """Translate a backend job state into the SDK ``Status``.

        Any state without an explicit entry (including ``mmt_state_deleted``) maps to ``Status.Pending``.
        """
        from lightning_sdk.status import Status

        known_states = {
            self.mmt_state_unspecified: Status.Pending,
            self.mmt_state_running: Status.Running,
            self.mmt_state_stopped: Status.Stopped,
            self.mmt_state_completed: Status.Completed,
            self.mmt_state_failed: Status.Failed,
        }
        return known_states.get(str(state), Status.Pending)

    def _get_job_machine_from_spec(self, spec: V1JobSpec) -> "Machine":
        """Resolve the machine of a job spec via ``_COMPUTE_NAME_TO_MACHINE``.

        ``spec.instance_type`` is looked up first, then ``spec.instance_name``. When neither is a known
        key, the raw ``instance_type`` (or ``instance_name`` if that is falsy) is returned unchanged, so
        the result is not necessarily a ``Machine`` member.
        """
        instance_name = spec.instance_name
        instance_type = spec.instance_type

        return _COMPUTE_NAME_TO_MACHINE.get(
            instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
        )
@@ -1,5 +1,4 @@
1
1
  import os
2
- from dataclasses import dataclass
3
2
  from pathlib import Path
4
3
  from typing import Dict, List, Optional
5
4
 
@@ -260,22 +259,16 @@ class TeamspaceApi:
260
259
  name: str,
261
260
  version: str,
262
261
  download_dir: Path,
263
- teamspace_id: str,
262
+ teamspace_name: str,
263
+ teamspace_owner_name: str,
264
264
  progress_bar: bool = True,
265
265
  ) -> List[str]:
266
266
  return _download_model_files(
267
267
  client=self._client,
268
- teamspace_id=teamspace_id,
268
+ teamspace_name=teamspace_name,
269
+ teamspace_owner_name=teamspace_owner_name,
269
270
  name=name,
270
271
  version=version,
271
272
  download_dir=download_dir,
272
273
  progress_bar=progress_bar,
273
274
  )
274
-
275
-
276
- @dataclass
277
- class UploadedModelInfo:
278
- name: str
279
- version: str
280
- teamspace: str
281
- cluster: str
@@ -513,7 +513,8 @@ def _get_model_version(client: LightningClient, teamspace_id: str, name: str, ve
513
513
 
514
514
  def _download_model_files(
515
515
  client: LightningClient,
516
- teamspace_id: str,
516
+ teamspace_name: str,
517
+ teamspace_owner_name: str,
517
518
  name: str,
518
519
  version: str,
519
520
  download_dir: Path,
@@ -521,7 +522,9 @@ def _download_model_files(
521
522
  num_workers: int = 20,
522
523
  ) -> List[str]:
523
524
  api = ModelsStoreApi(client.api_client)
524
- response = api.models_store_get_model_files(project_id=teamspace_id, name=name, version=version)
525
+ response = api.models_store_get_model_files(
526
+ project_name=teamspace_name, project_owner_name=teamspace_owner_name, name=name, version=version
527
+ )
525
528
 
526
529
  pbar = None
527
530
  if progress_bar:
@@ -541,7 +544,7 @@ def _download_model_files(
541
544
  client=client,
542
545
  model_id=response.model_id,
543
546
  version=response.version,
544
- teamspace_id=teamspace_id,
547
+ teamspace_id=response.project_id,
545
548
  remote_path=filepath,
546
549
  file_path=str(local_file),
547
550
  num_workers=num_workers,
@@ -4,8 +4,8 @@ from pathlib import Path
4
4
  from typing import Optional
5
5
 
6
6
  from lightning_sdk.cli.exceptions import StudioCliError
7
- from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
8
7
  from lightning_sdk.cli.studios_menu import _StudiosMenu
8
+ from lightning_sdk.models import download_model
9
9
  from lightning_sdk.studio import Studio
10
10
  from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
11
11
 
@@ -21,10 +21,8 @@ class _Downloads(_StudiosMenu):
21
21
  This should have the format <ORGANIZATION-NAME>/<TEAMSPACE-NAME>/<MODEL-NAME>.
22
22
  download_dir: The directory where the Model should be downloaded.
23
23
  """
24
- org_name, teamspace_name, model_name = _parse_model_name(name)
25
- teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
26
- teamspace.download_model(
27
- name=model_name,
24
+ download_model(
25
+ name=name,
28
26
  download_dir=download_dir,
29
27
  progress_bar=True,
30
28
  )
@@ -0,0 +1,137 @@
1
+ from typing import Dict, Optional
2
+
3
+ from fire import Fire
4
+
5
+ from lightning_sdk._mmt import MMT
6
+ from lightning_sdk.api.studio_api import _cloud_url
7
+ from lightning_sdk.lightning_cloud.login import Auth
8
+ from lightning_sdk.machine import Machine
9
+ from lightning_sdk.teamspace import Teamspace
10
+
11
# All machine values, interpolated into the help text of the CLI's ``run`` command.
_MACHINE_VALUES = tuple(machine.value for machine in Machine)
12
+
13
+
14
class MMTCLI:
    """Command line interface (CLI) to interact with/manage Lightning AI MMT."""

    def __init__(self) -> None:
        """Attach the generated help text to ``run`` so the CLI can display the machine options."""
        # Need to set the docstring here for f-strings to work.
        # Sadly this is the only way to really show options as f-strings are not allowed as docstrings directly
        # and fire does not show values for literals, just that it is a literal.
        docstr = f"""Run async workloads on multiple machines using a docker image.

        Args:
            name: The name of the job. Needs to be unique within the teamspace.
            num_machines: The number of Machines to run on. Defaults to 2 Machines
            machine: The machine type to run the job on. One of {", ".join(_MACHINE_VALUES)}. Defaults to CPU
            command: The command to run inside your job. Required if using a studio. Optional if using an image.
                If not provided for images, will run the container entrypoint and default command.
            studio: The studio env to run the job with. Mutually exclusive with image.
            image: The docker image to run the job with. Mutually exclusive with studio.
            teamspace: The teamspace the job should be associated with. Defaults to the current teamspace.
            org: The organization owning the teamspace (if any). Defaults to the current organization.
            user: The user owning the teamspace (if any). Defaults to the current user.
            cluster: The cluster to run the job on. Defaults to the studio cluster if running with studio compute env.
                If not provided will fall back to the teamspaces default cluster.
            env: Environment variables to set inside the job.
            interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
            image_credentials: The credentials used to pull the image. Required if the image is private.
                This should be the name of the respective credentials secret created on the Lightning AI platform.
            cluster_auth: Whether to authenticate with the cluster to pull the image.
                Required if the registry is part of a cluster provider (e.g. ECR).
            artifacts_local: The path inside the docker container you want to persist artifacts from.
                CAUTION: When setting this to "/", it will effectively erase your container.
                Only supported for jobs with a docker image compute environment.
            artifacts_remote: The remote storage to persist your artifacts to.
                Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
                PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
                E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
                within it.
                Note that the connection needs to be added to the teamspace already in order for it to be found.
                Only supported for jobs with a docker image compute environment.
        """
        # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
        #   might need to switch to explicit cli definition
        self.run.__func__.__doc__ = docstr

    def login(self) -> None:
        """Login to Lightning AI Studios."""
        auth = Auth()
        # Drop any existing credentials first so that authenticate() starts from a clean state.
        auth.clear()

        try:
            auth.authenticate()
        except ConnectionError:
            raise RuntimeError(f"Unable to connect to {_cloud_url()}. Please check your internet connection.") from None

    def logout(self) -> None:
        """Logout from Lightning AI Studios."""
        auth = Auth()
        auth.clear()

    # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
    # see https://github.com/google/python-fire/pull/513
    # might need to move to different cli library
    def run(
        self,
        name: Optional[str] = None,
        num_machines: int = 2,
        machine: Optional[str] = None,
        command: Optional[str] = None,
        studio: Optional[str] = None,
        image: Optional[str] = None,
        teamspace: Optional[str] = None,
        org: Optional[str] = None,
        user: Optional[str] = None,
        cluster: Optional[str] = None,
        env: Optional[Dict[str, str]] = None,
        interruptible: bool = False,
        image_credentials: Optional[str] = None,
        cluster_auth: bool = False,
        artifacts_local: Optional[str] = None,
        artifacts_remote: Optional[str] = None,
    ) -> None:
        # NOTE: no docstring on purpose -- __init__ assigns the generated help text to this function.
        # Fail fast: reject unsupported invocations before a Teamspace object is constructed.
        if image is None:
            raise RuntimeError("Currently only docker images are supported")

        if name is None:
            from datetime import datetime

            # Generate a timestamp-based default name.
            timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            name = f"mmt-{timestr}"

        if machine is None:
            # TODO: infer from studio
            machine = "CPU"
        machine_enum = Machine(machine.upper())

        teamspace = Teamspace(name=teamspace, org=org, user=user)
        if cluster is None:
            cluster = teamspace.default_cluster

        MMT.run(
            name=name,
            num_machines=num_machines,
            machine=machine_enum,
            command=command,
            studio=studio,
            image=image,
            teamspace=teamspace,
            org=org,
            user=user,
            cluster=cluster,
            env=env,
            interruptible=interruptible,
            image_credentials=image_credentials,
            cluster_auth=cluster_auth,
            artifacts_local=artifacts_local,
            artifacts_remote=artifacts_remote,
        )
129
+
130
+
131
def main_cli() -> None:
    """Expose an ``MMTCLI`` instance through Fire under the command name ``_mmt``."""
    cli = MMTCLI()
    Fire(cli, name="_mmt")


# Allow running this module directly as a script.
if __name__ == "__main__":
    main_cli()
lightning_sdk/cli/run.py CHANGED
@@ -40,7 +40,19 @@ class _Run:
40
40
  This should be the name of the respective credentials secret created on the Lightning AI platform.
41
41
  cluster_auth: Whether to authenticate with the cluster to pull the image.
42
42
  Required if the registry is part of a cluster provider (e.g. ECR).
43
+ artifacts_local: The path of inside the docker container, you want to persist images from.
44
+ CAUTION: When setting this to "/", it will effectively erase your container.
45
+ Only supported for jobs with a docker image compute environment.
46
+ artifacts_remote: The remote storage to persist your artifacts to.
47
+ Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
48
+ PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
49
+ E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
50
+ within it.
51
+ Note that the connection needs to be added to the teamspace already in order for it to be found.
52
+ Only supported for jobs with a docker image compute environment.
43
53
  """
54
+ # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
55
+ # might need to switch to explicit cli definition
44
56
  self.job.__func__.__doc__ = docstr
45
57
 
46
58
  # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
@@ -61,6 +73,8 @@ class _Run:
61
73
  interruptible: bool = False,
62
74
  image_credentials: Optional[str] = None,
63
75
  cluster_auth: bool = False,
76
+ artifacts_local: Optional[str] = None,
77
+ artifacts_remote: Optional[str] = None,
64
78
  ) -> None:
65
79
  machine_enum = Machine(machine.upper())
66
80
  Job.run(
@@ -77,4 +91,6 @@ class _Run:
77
91
  interruptible=interruptible,
78
92
  image_credentials=image_credentials,
79
93
  cluster_auth=cluster_auth,
94
+ artifacts_local=artifacts_local,
95
+ artifacts_remote=artifacts_remote,
80
96
  )
@@ -9,8 +9,8 @@ from tqdm import tqdm
9
9
 
10
10
  from lightning_sdk.api.utils import _get_cloud_url
11
11
  from lightning_sdk.cli.exceptions import StudioCliError
12
- from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
13
12
  from lightning_sdk.cli.studios_menu import _StudiosMenu
13
+ from lightning_sdk.models import upload_model
14
14
  from lightning_sdk.studio import Studio
15
15
  from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
16
16
 
@@ -20,7 +20,7 @@ class _Uploads(_StudiosMenu):
20
20
 
21
21
  _studio_upload_status_path = "~/.lightning/studios/uploads"
22
22
 
23
- def model(self, name: str, path: Optional[str] = None, cloud_account: Optional[str] = None) -> None:
23
+ def model(self, name: str, path: str = ".", cloud_account: Optional[str] = None) -> None:
24
24
  """Upload a Model.
25
25
 
26
26
  Args:
@@ -29,14 +29,7 @@ class _Uploads(_StudiosMenu):
29
29
  path: The path to the file or directory you want to upload. Defaults to the current directory.
30
30
  cloud_account: The name of the cloud account to store the Model in.
31
31
  """
32
- org_name, teamspace_name, model_name = _parse_model_name(name)
33
- teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
34
- teamspace.upload_model(
35
- path=path or ".",
36
- name=model_name,
37
- progress_bar=True,
38
- cluster_id=cloud_account,
39
- )
32
+ upload_model(name, path, cloud_account=cloud_account)
40
33
 
41
34
  def _resolve_studio(self, studio: Optional[str]) -> Studio:
42
35
  user = _get_authed_user()
lightning_sdk/job/base.py CHANGED
@@ -52,6 +52,8 @@ class _BaseJob(ABC):
52
52
  interruptible: bool = False,
53
53
  image_credentials: Optional[str] = None,
54
54
  cluster_auth: bool = False,
55
+ artifacts_local: Optional[str] = None,
56
+ artifacts_remote: Optional[str] = None,
55
57
  ) -> "_BaseJob":
56
58
  from lightning_sdk.studio import Studio
57
59
 
@@ -89,14 +91,30 @@ class _BaseJob(ABC):
89
91
  if cluster_auth:
90
92
  raise ValueError("cluster_auth is only supported when using a custom image")
91
93
 
94
+ if artifacts_local is not None or artifacts_remote is not None:
95
+ raise ValueError(
96
+ "Specifying artifacts persistence is supported for docker images only. "
97
+ "Other jobs will automatically persist artifacts to the teamspace distributed filesystem."
98
+ )
99
+
92
100
  else:
93
101
  if studio is not None:
94
102
  raise RuntimeError(
95
103
  "image and studio are mutually exclusive as both define the environment to run the job in"
96
104
  )
97
105
 
106
+ # either both of them need to be specified or neither
107
+ if bool(artifacts_local) != bool(artifacts_remote):
108
+ raise ValueError("Artifact persistence requires both artifacts_local and artifacts_remote to be set")
109
+
110
+ if artifacts_remote and len(artifacts_remote.split(":")) != 3:
111
+ raise ValueError(
112
+ "Artifact persistence requires exactly three arguments separated by colon of kind "
113
+ f"<CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>, got {artifacts_local}"
114
+ )
115
+
98
116
  inst = cls(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=False)
99
- inst._submit(
117
+ return inst._submit(
100
118
  machine=machine,
101
119
  cluster=cluster,
102
120
  command=command,
@@ -106,8 +124,9 @@ class _BaseJob(ABC):
106
124
  interruptible=interruptible,
107
125
  image_credentials=image_credentials,
108
126
  cluster_auth=cluster_auth,
127
+ artifacts_local=artifacts_local,
128
+ artifacts_remote=artifacts_remote,
109
129
  )
110
- return inst
111
130
 
112
131
  @abstractmethod
113
132
  def _submit(
@@ -121,7 +140,9 @@ class _BaseJob(ABC):
121
140
  cluster: Optional[str] = None,
122
141
  image_credentials: Optional[str] = None,
123
142
  cluster_auth: bool = False,
124
- ) -> None:
143
+ artifacts_local: Optional[str] = None,
144
+ artifacts_remote: Optional[str] = None,
145
+ ) -> "_BaseJob":
125
146
  """Submits a job and updates the internal _job attribute as well as the _name attribute."""
126
147
 
127
148
  @abstractmethod
lightning_sdk/job/job.py CHANGED
@@ -60,6 +60,8 @@ class Job(_BaseJob):
60
60
  interruptible: bool = False,
61
61
  image_credentials: Optional[str] = None,
62
62
  cluster_auth: bool = False,
63
+ artifacts_local: Optional[str] = None,
64
+ artifacts_remote: Optional[str] = None,
63
65
  ) -> "Job":
64
66
  ret_val = super().run(
65
67
  name=name,
@@ -75,6 +77,8 @@ class Job(_BaseJob):
75
77
  interruptible=interruptible,
76
78
  image_credentials=image_credentials,
77
79
  cluster_auth=cluster_auth,
80
+ artifacts_local=artifacts_local,
81
+ artifacts_remote=artifacts_remote,
78
82
  )
79
83
  # required for typing with "Job"
80
84
  assert isinstance(ret_val, cls)
@@ -91,8 +95,10 @@ class Job(_BaseJob):
91
95
  cluster: Optional[str] = None,
92
96
  image_credentials: Optional[str] = None,
93
97
  cluster_auth: bool = False,
98
+ artifacts_local: Optional[str] = None,
99
+ artifacts_remote: Optional[str] = None,
94
100
  ) -> None:
95
- return self._internal_job._submit(
101
+ self._job = self._internal_job._submit(
96
102
  machine=machine,
97
103
  cluster=cluster,
98
104
  command=command,
@@ -102,7 +108,10 @@ class Job(_BaseJob):
102
108
  interruptible=interruptible,
103
109
  image_credentials=image_credentials,
104
110
  cluster_auth=cluster_auth,
111
+ artifacts_local=artifacts_local,
112
+ artifacts_remote=artifacts_remote,
105
113
  )
114
+ return self
106
115
 
107
116
  def stop(self) -> None:
108
117
  return self._internal_job.stop()
lightning_sdk/job/v1.py CHANGED
@@ -69,13 +69,18 @@ class _JobV1(_BaseJob):
69
69
  cluster: Optional[str] = None,
70
70
  image_credentials: Optional[str] = None,
71
71
  cluster_auth: bool = False,
72
- ) -> None:
72
+ artifacts_local: Optional[str] = None,
73
+ artifacts_remote: Optional[str] = None,
74
+ ) -> "_JobV1":
73
75
  if studio is None:
74
76
  raise ValueError("Studio is required for submitting jobs")
75
77
 
76
78
  if image is not None or image_credentials is not None or cluster_auth:
77
79
  raise ValueError("Image is not supported for submitting jobs")
78
80
 
81
+ if artifacts_local is not None or artifacts_remote is not None:
82
+ raise ValueError("Specifying how to persist artifacts is not yet supported with jobs")
83
+
79
84
  if env is not None:
80
85
  raise ValueError("Environment variables are not supported for submitting jobs")
81
86
 
@@ -95,6 +100,7 @@ class _JobV1(_BaseJob):
95
100
  )
96
101
  self._name = _submitted.name
97
102
  self._job = _submitted
103
+ return self
98
104
 
99
105
  def _update_internal_job(self) -> None:
100
106
  try:
lightning_sdk/job/v2.py CHANGED
@@ -36,7 +36,9 @@ class _JobV2(_BaseJob):
36
36
  cluster: Optional[str] = None,
37
37
  image_credentials: Optional[str] = None,
38
38
  cluster_auth: bool = False,
39
- ) -> None:
39
+ artifacts_local: Optional[str] = None,
40
+ artifacts_remote: Optional[str] = None,
41
+ ) -> "_JobV2":
40
42
  # Command is required if Studio is provided to know what to run
41
43
  # Image is mutually exclusive with Studio
42
44
  # Command is optional for Image
@@ -66,22 +68,21 @@ class _JobV2(_BaseJob):
66
68
  env=env,
67
69
  image_credentials=image_credentials,
68
70
  cluster_auth=cluster_auth,
71
+ artifacts_local=artifacts_local,
72
+ artifacts_remote=artifacts_remote,
69
73
  )
70
74
  self._job = submitted
71
75
  self._name = submitted.name
76
+ return self
72
77
 
73
78
  def stop(self) -> None:
74
- if self._job is None:
75
- self._update_internal_job()
76
-
77
- self._job_api.stop_job(job_id=self._job.id, teamspace_id=self._teamspace.id)
79
+ self._job_api.stop_job(job_id=self._guaranteed_job.id, teamspace_id=self._teamspace.id)
78
80
 
79
81
  def delete(self) -> None:
80
- if self._job is None:
81
- self._update_internal_job()
82
-
83
82
  self._job_api.delete_job(
84
- job_id=self._job.id, teamspace_id=self._teamspace.id, cloudspace_id=self._job.spec.cloudspace_id
83
+ job_id=self._guaranteed_job.id,
84
+ teamspace_id=self._teamspace.id,
85
+ cloudspace_id=self._guaranteed_job.spec.cloudspace_id,
85
86
  )
86
87
 
87
88
  @property
@@ -112,10 +113,18 @@ class _JobV2(_BaseJob):
112
113
 
113
114
  @property
114
115
  def artifact_path(self) -> Optional[str]:
116
+ if self._guaranteed_job.spec.image != "":
117
+ if self._guaranteed_job.spec.artifacts_destination != "":
118
+ splits = self._guaranteed_job.spec.artifacts_destination.split(":")
119
+ return f"/teamspace/{splits[0]}_connections/{splits[1]}/{splits[2]}"
120
+ return None
121
+
115
122
  return f"/teamspace/jobs/{self._guaranteed_job.name}/artifacts"
116
123
 
117
124
  @property
118
125
  def snapshot_path(self) -> Optional[str]:
126
+ if self._guaranteed_job.spec.image != "":
127
+ return None
119
128
  return f"/teamspace/jobs/{self._guaranteed_job.name}/snapshot"
120
129
 
121
130
  @property