lightning-sdk 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/job_api.py +18 -13
  3. lightning_sdk/api/lit_container_api.py +41 -11
  4. lightning_sdk/api/mmt_api.py +18 -13
  5. lightning_sdk/api/utils.py +52 -0
  6. lightning_sdk/cli/download.py +20 -1
  7. lightning_sdk/cli/run.py +71 -21
  8. lightning_sdk/cli/serve.py +1 -5
  9. lightning_sdk/cli/upload.py +33 -15
  10. lightning_sdk/helpers.py +1 -1
  11. lightning_sdk/job/base.py +16 -5
  12. lightning_sdk/job/job.py +30 -27
  13. lightning_sdk/job/v1.py +9 -5
  14. lightning_sdk/job/v2.py +14 -14
  15. lightning_sdk/job/work.py +2 -2
  16. lightning_sdk/lightning_cloud/login.py +4 -1
  17. lightning_sdk/lightning_cloud/openapi/__init__.py +3 -0
  18. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +5 -1
  19. lightning_sdk/lightning_cloud/openapi/api/lit_registry_service_api.py +113 -0
  20. lightning_sdk/lightning_cloud/openapi/models/__init__.py +3 -0
  21. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  22. lightning_sdk/lightning_cloud/openapi/models/litregistry_lit_repo_name_body.py +123 -0
  23. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  24. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  25. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
  26. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
  27. lightning_sdk/lightning_cloud/openapi/models/v1_path_mapping.py +175 -0
  28. lightning_sdk/lightning_cloud/openapi/models/v1_update_lit_repository_response.py +97 -0
  29. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +79 -79
  30. lightning_sdk/lit_container.py +19 -0
  31. lightning_sdk/mmt/base.py +40 -30
  32. lightning_sdk/mmt/mmt.py +30 -26
  33. lightning_sdk/mmt/v1.py +6 -3
  34. lightning_sdk/mmt/v2.py +16 -15
  35. lightning_sdk/models.py +5 -4
  36. lightning_sdk/utils/resolve.py +7 -0
  37. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/METADATA +2 -2
  38. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/RECORD +42 -39
  39. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/LICENSE +0 -0
  40. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/WHEEL +0 -0
  41. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/entry_points.txt +0 -0
  42. {lightning_sdk-0.1.49.dist-info → lightning_sdk-0.1.51.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -29,5 +29,5 @@ __all__ = [
29
29
  "AIHub",
30
30
  ]
31
31
 
32
- __version__ = "0.1.49"
32
+ __version__ = "0.1.51"
33
33
  _check_version_and_prompt_upgrade(__version__)
@@ -1,16 +1,15 @@
1
1
  import time
2
- from typing import TYPE_CHECKING, Dict, List, Optional
2
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
3
3
  from urllib.request import urlopen
4
4
 
5
5
  from lightning_sdk.api.utils import (
6
6
  _COMPUTE_NAME_TO_MACHINE,
7
- _MACHINE_TO_COMPUTE_NAME,
8
7
  _create_app,
8
+ _machine_to_compute_name,
9
9
  remove_datetime_prefix,
10
+ resolve_path_mappings,
10
11
  )
11
- from lightning_sdk.api.utils import (
12
- _get_cloud_url as _cloud_url,
13
- )
12
+ from lightning_sdk.api.utils import _get_cloud_url as _cloud_url
14
13
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
15
14
  from lightning_sdk.lightning_cloud.openapi import (
16
15
  AppinstancesIdBody,
@@ -120,7 +119,7 @@ class JobApiV1:
120
119
  studio_id: str,
121
120
  teamspace_id: str,
122
121
  cloud_account: str,
123
- machine: Machine,
122
+ machine: Union[Machine, str],
124
123
  interruptible: bool,
125
124
  ) -> Externalv1LightningappInstance:
126
125
  """Creates an arbitrary app."""
@@ -130,7 +129,7 @@ class JobApiV1:
130
129
  teamspace_id=teamspace_id,
131
130
  cloud_account=cloud_account,
132
131
  plugin_type="job",
133
- compute=_MACHINE_TO_COMPUTE_NAME[machine],
132
+ compute=_machine_to_compute_name(machine),
134
133
  name=name,
135
134
  entrypoint=command,
136
135
  interruptible=interruptible,
@@ -209,24 +208,31 @@ class JobApiV2:
209
208
  teamspace_id: str,
210
209
  studio_id: Optional[str],
211
210
  image: Optional[str],
212
- machine: Machine,
211
+ machine: Union[Machine, str],
213
212
  interruptible: bool,
214
213
  env: Optional[Dict[str, str]],
215
214
  image_credentials: Optional[str],
216
215
  cloud_account_auth: bool,
217
- artifacts_local: Optional[str],
218
- artifacts_remote: Optional[str],
219
216
  entrypoint: str,
217
+ path_mappings: Optional[Dict[str, str]],
218
+ artifacts_local: Optional[str], # deprecated in favor of path_mappings
219
+ artifacts_remote: Optional[str], # deprecated in favor of path_mappings
220
220
  ) -> V1Job:
221
221
  env_vars = []
222
222
  if env is not None:
223
223
  for k, v in env.items():
224
224
  env_vars.append(V1EnvVar(name=k, value=v))
225
225
 
226
- instance_name = _MACHINE_TO_COMPUTE_NAME[machine]
226
+ instance_name = _machine_to_compute_name(machine)
227
227
 
228
228
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
229
229
 
230
+ path_mappings_list = resolve_path_mappings(
231
+ mappings=path_mappings or {},
232
+ artifacts_local=artifacts_local,
233
+ artifacts_remote=artifacts_remote,
234
+ )
235
+
230
236
  spec = V1JobSpec(
231
237
  cloudspace_id=studio_id or "",
232
238
  cluster_id=cloud_account or "",
@@ -239,8 +245,7 @@ class JobApiV2:
239
245
  spot=interruptible,
240
246
  image_cluster_credentials=cloud_account_auth,
241
247
  image_secret_ref=image_credentials or "",
242
- artifacts_source=artifacts_local or "",
243
- artifacts_destination=artifacts_remote or "",
248
+ path_mappings=path_mappings_list,
244
249
  )
245
250
  body = ProjectIdJobsBody(name=name, spec=spec)
246
251
 
@@ -1,15 +1,36 @@
1
1
  from typing import Generator, List
2
2
 
3
+ import docker
4
+
3
5
  from lightning_sdk.api.utils import _get_registry_url
6
+ from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
4
7
  from lightning_sdk.lightning_cloud.openapi.models import V1DeleteLitRepositoryResponse
5
8
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
6
9
  from lightning_sdk.teamspace import Teamspace
7
10
 
8
11
 
12
+ class LCRAuthFailedError(Exception):
13
+ def __init__(self) -> None:
14
+ super().__init__("Failed to authenticate with Lightning Container Registry")
15
+
16
+
9
17
  class LitContainerApi:
10
18
  def __init__(self) -> None:
11
19
  self._client = LightningClient(max_tries=3)
12
20
 
21
+ try:
22
+ self._docker_client = docker.from_env()
23
+ self._docker_client.ping()
24
+ except docker.errors.DockerException as e:
25
+ raise RuntimeError(f"Failed to connect to Docker daemon: {e!s}. Is Docker running?") from None
26
+
27
+ def authenticate(self) -> bool:
28
+ authed_user = self._client.auth_service_get_user()
29
+ username = authed_user.username
30
+ api_key = authed_user.api_key
31
+ resp = self._docker_client.login(username, password=api_key, registry=_get_registry_url())
32
+ return resp["Status"] == "Login Succeeded"
33
+
13
34
  def list_containers(self, project_id: str) -> List:
14
35
  project = self._client.lit_registry_service_get_lit_project_registry(project_id)
15
36
  return project.repositories
@@ -21,22 +42,31 @@ class LitContainerApi:
21
42
  raise ValueError(f"Could not delete container {container} from project {project_id}") from ex
22
43
 
23
44
  def upload_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
24
- import docker
25
-
26
- try:
27
- client = docker.from_env()
28
- client.ping()
29
- except docker.errors.DockerException as e:
30
- raise RuntimeError(f"Failed to connect to Docker daemon: {e!s}. Is Docker running?") from None
31
-
32
45
  try:
33
- client.images.get(container)
46
+ self._docker_client.images.get(container)
34
47
  except docker.errors.ImageNotFound:
35
48
  raise ValueError(f"Container {container} does not exist") from None
36
49
 
37
50
  registry_url = _get_registry_url()
38
51
  repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
39
- tagged = client.api.tag(container, repository, tag)
52
+ tagged = self._docker_client.api.tag(container, repository, tag)
40
53
  if not tagged:
41
54
  raise ValueError(f"Could not tag container {container} with {repository}:{tag}")
42
- return client.api.push(repository, stream=True, decode=True)
55
+ lines = self._docker_client.api.push(repository, stream=True, decode=True)
56
+ for line in lines:
57
+ if "errorDetail" in line and "authorization failed" in line["error"]:
58
+ raise LCRAuthFailedError()
59
+ yield line
60
+ yield {
61
+ "finish": True,
62
+ "url": f"{LIGHTNING_CLOUD_URL}/{teamspace.owner.name}/{teamspace.name}/containers/{container}",
63
+ }
64
+
65
+ def download_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
66
+ registry_url = _get_registry_url()
67
+ repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
68
+ try:
69
+ self._docker_client.images.pull(repository, tag=tag)
70
+ except docker.errors.APIError as e:
71
+ raise ValueError(f"Could not pull container {container} from {repository}:{tag}") from e
72
+ return self._docker_client.api.tag(repository, container, tag)
@@ -1,16 +1,15 @@
1
1
  import json
2
2
  import time
3
- from typing import TYPE_CHECKING, Dict, List, Optional
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
4
4
 
5
5
  from lightning_sdk.api.job_api import JobApiV1
6
6
  from lightning_sdk.api.utils import (
7
7
  _COMPUTE_NAME_TO_MACHINE,
8
- _MACHINE_TO_COMPUTE_NAME,
9
8
  _create_app,
9
+ _machine_to_compute_name,
10
+ resolve_path_mappings,
10
11
  )
11
- from lightning_sdk.api.utils import (
12
- _get_cloud_url as _cloud_url,
13
- )
12
+ from lightning_sdk.api.utils import _get_cloud_url as _cloud_url
14
13
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
15
14
  from lightning_sdk.lightning_cloud.openapi import (
16
15
  Externalv1LightningappInstance,
@@ -43,13 +42,13 @@ class MMTApiV1(JobApiV1):
43
42
  cloud_account: Optional[str],
44
43
  teamspace_id: str,
45
44
  studio_id: str,
46
- machine: Machine,
45
+ machine: Union[Machine, str],
47
46
  interruptible: bool,
48
47
  strategy: str,
49
48
  ) -> Externalv1LightningappInstance:
50
49
  """Creates a multi-machine job with given commands."""
51
50
  distributed_args = {
52
- "cloud_compute": _MACHINE_TO_COMPUTE_NAME[machine],
51
+ "cloud_compute": _machine_to_compute_name(machine),
53
52
  "num_instances": num_machines,
54
53
  "strategy": strategy,
55
54
  }
@@ -80,24 +79,31 @@ class MMTApiV2:
80
79
  teamspace_id: str,
81
80
  studio_id: Optional[str],
82
81
  image: Optional[str],
83
- machine: Machine,
82
+ machine: Union[Machine, str],
84
83
  interruptible: bool,
85
84
  env: Optional[Dict[str, str]],
86
85
  image_credentials: Optional[str],
87
86
  cloud_account_auth: bool,
88
- artifacts_local: Optional[str],
89
- artifacts_remote: Optional[str],
90
87
  entrypoint: str,
88
+ path_mappings: Optional[Dict[str, str]],
89
+ artifacts_local: Optional[str], # deprecated in favor of path_mappings
90
+ artifacts_remote: Optional[str], # deprecated in favor of path_mappings
91
91
  ) -> V1MultiMachineJob:
92
92
  env_vars = []
93
93
  if env is not None:
94
94
  for k, v in env.items():
95
95
  env_vars.append(V1EnvVar(name=k, value=v))
96
96
 
97
- instance_name = _MACHINE_TO_COMPUTE_NAME[machine]
97
+ instance_name = _machine_to_compute_name(machine)
98
98
 
99
99
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
100
100
 
101
+ path_mappings_list = resolve_path_mappings(
102
+ mappings=path_mappings or {},
103
+ artifacts_local=artifacts_local,
104
+ artifacts_remote=artifacts_remote,
105
+ )
106
+
101
107
  spec = V1JobSpec(
102
108
  cloudspace_id=studio_id or "",
103
109
  cluster_id=cloud_account or "",
@@ -110,8 +116,7 @@ class MMTApiV2:
110
116
  spot=interruptible,
111
117
  image_cluster_credentials=cloud_account_auth,
112
118
  image_secret_ref=image_credentials or "",
113
- artifacts_source=artifacts_local or "",
114
- artifacts_destination=artifacts_remote or "",
119
+ path_mappings=path_mappings_list,
115
120
  )
116
121
  body = ProjectIdMultimachinejobsBody(
117
122
  name=name, spec=spec, cluster_id=cloud_account or "", machines=num_machines
@@ -24,6 +24,7 @@ from lightning_sdk.lightning_cloud.openapi import (
24
24
  UploadsUploadIdBody,
25
25
  V1CompletedPart,
26
26
  V1CompleteUpload,
27
+ V1PathMapping,
27
28
  V1PresignedUrl,
28
29
  V1SignedUrl,
29
30
  V1UploadProjectArtifactPartsResponse,
@@ -614,3 +615,54 @@ def remove_datetime_prefix(text: str) -> str:
614
615
  # lines looks something like
615
616
  # '[2025-01-08T14:15:03.797142418Z] ⚡ ~ echo Hello\n[2025-01-08T14:15:03.803077717Z] Hello\n'
616
617
  return re.sub(r"^\[.*?\] ", "", text, flags=re.MULTILINE)
618
+
619
+
620
+ def resolve_path_mappings(
621
+ mappings: Dict[str, str],
622
+ artifacts_local: Optional[str],
623
+ artifacts_remote: Optional[str],
624
+ ) -> List[V1PathMapping]:
625
+ path_mappings_list = []
626
+ for k, v in mappings.items():
627
+ splitted = str(v).rsplit(":", 1)
628
+ connection_name: str
629
+ connection_path: str
630
+ if len(splitted) == 1:
631
+ connection_name = splitted[0]
632
+ connection_path = ""
633
+ else:
634
+ connection_name, connection_path = splitted
635
+
636
+ path_mappings_list.append(
637
+ V1PathMapping(
638
+ connection_name=connection_name,
639
+ connection_path=connection_path,
640
+ container_path=k,
641
+ )
642
+ )
643
+
644
+ if artifacts_remote:
645
+ splitted = str(artifacts_remote).rsplit(":", 2)
646
+ if len(splitted) not in (2, 3):
647
+ raise RuntimeError(
648
+ f"Artifacts remote need to be of format efs:connection_name[:path] but got {artifacts_remote}"
649
+ )
650
+ else:
651
+ if not artifacts_local:
652
+ raise RuntimeError("If Artifacts remote is specified, artifacts local should be specified as well")
653
+
654
+ if len(splitted) == 2:
655
+ _, connection_name = splitted
656
+ connection_path = ""
657
+ else:
658
+ _, connection_name, connection_path = splitted
659
+
660
+ path_mappings_list.append(
661
+ V1PathMapping(
662
+ connection_name=connection_name,
663
+ connection_path=connection_path,
664
+ container_path=artifacts_local,
665
+ )
666
+ )
667
+
668
+ return path_mappings_list
@@ -3,14 +3,18 @@ import re
3
3
  from pathlib import Path
4
4
  from typing import Optional
5
5
 
6
+ from rich.console import Console
7
+
8
+ from lightning_sdk.api.lit_container_api import LitContainerApi
6
9
  from lightning_sdk.cli.exceptions import StudioCliError
7
10
  from lightning_sdk.cli.studios_menu import _StudiosMenu
11
+ from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
8
12
  from lightning_sdk.models import download_model
9
13
  from lightning_sdk.studio import Studio
10
14
  from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
11
15
 
12
16
 
13
- class _Downloads(_StudiosMenu):
17
+ class _Downloads(_StudiosMenu, _TeamspacesMenu):
14
18
  """Download files and folders from Lightning AI."""
15
19
 
16
20
  def model(self, name: str, download_dir: str = ".") -> None:
@@ -130,3 +134,18 @@ class _Downloads(_StudiosMenu):
130
134
  f"Could not download the file from the given Studio {studio}. "
131
135
  "Please contact Lightning AI directly to resolve this issue."
132
136
  ) from e
137
+
138
+ def container(self, container: str, teamspace: Optional[str] = None, tag: str = "latest") -> None:
139
+ """Download a docker container from a teamspace.
140
+
141
+ Args:
142
+ container: The name of the container to download.
143
+ teamspace: The name of the teamspace to download the container from.
144
+ tag: The tag of the container to download.
145
+ """
146
+ resolved_teamspace = self._resolve_teamspace(teamspace)
147
+ console = Console()
148
+ with console.status("Downloading container..."):
149
+ api = LitContainerApi()
150
+ api.download_container(container, resolved_teamspace, tag)
151
+ console.print("Container downloaded successfully", style="green")
lightning_sdk/cli/run.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Dict, Optional
1
+ from typing import TYPE_CHECKING, Dict, Optional, Union
2
2
 
3
3
  from lightning_sdk.job import Job
4
4
  from lightning_sdk.machine import Machine
@@ -43,20 +43,28 @@ class _Run:
43
43
  This should be the name of the respective credentials secret created on the Lightning AI platform.
44
44
  cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
45
45
  Required if the registry is part of a cloud provider (e.g. ECR).
46
- artifacts_local: The path of inside the docker container, you want to persist images from.
46
+ entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
47
+ just runs the provided command in a standard shell.
48
+ To use the pre-defined entrypoint of the provided image, set this to an empty string.
49
+ Only applicable when submitting docker jobs.
50
+ path_mappings: Maps path inside of containers to paths inside data-connections.
51
+ Should be a comma separated list of form:
52
+ <MAPPING_1>,<MAPPING_2>,...
53
+ where each mapping is of the form
54
+ <CONTAINER_PATH_1>:<CONNECTION_NAME_1>:<PATH_WITHIN_CONNECTION_1> and
55
+ omitting the path inside the connection defaults to the connections root.
56
+ artifacts_local: Deprecated in favor of path_mappings.
57
+ The path of inside the docker container, you want to persist images from.
47
58
  CAUTION: When setting this to "/", it will effectively erase your container.
48
59
  Only supported for jobs with a docker image compute environment.
49
- artifacts_remote: The remote storage to persist your artifacts to.
60
+ artifacts_remote: Deprecated in favor of path_mappings.
61
+ The remote storage to persist your artifacts to.
50
62
  Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
51
63
  PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
52
64
  E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
53
65
  within it.
54
66
  Note that the connection needs to be added to the teamspace already in order for it to be found.
55
67
  Only supported for jobs with a docker image compute environment.
56
- entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
57
- just runs the provided command in a standard shell.
58
- To use the pre-defined entrypoint of the provided image, set this to an empty string.
59
- Only applicable when submitting docker jobs.
60
68
  """
61
69
  # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
62
70
  # might need to switch to explicit cli definition
@@ -87,20 +95,28 @@ class _Run:
87
95
  This should be the name of the respective credentials secret created on the Lightning AI platform.
88
96
  cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
89
97
  Required if the registry is part of a cloud provider (e.g. ECR).
90
- artifacts_local: The path of inside the docker container, you want to persist images from.
98
+ entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
99
+ just runs the provided command in a standard shell.
100
+ To use the pre-defined entrypoint of the provided image, set this to an empty string.
101
+ Only applicable when submitting docker jobs.
102
+ path_mappings: Maps path inside of containers to paths inside data-connections.
103
+ Should be a comma separated list of form:
104
+ <MAPPING_1>,<MAPPING_2>,...
105
+ where each mapping is of the form
106
+ <CONTAINER_PATH_1>:<CONNECTION_NAME_1>:<PATH_WITHIN_CONNECTION_1> and
107
+ omitting the path inside the connection defaults to the connections root.
108
+ artifacts_local: Deprecated in favor of path_mappings.
109
+ The path of inside the docker container, you want to persist images from.
91
110
  CAUTION: When setting this to "/", it will effectively erase your container.
92
111
  Only supported for jobs with a docker image compute environment.
93
- artifacts_remote: The remote storage to persist your artifacts to.
112
+ artifacts_remote: Deprecated in favor of path_mappings.
113
+ The remote storage to persist your artifacts to.
94
114
  Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
95
115
  PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
96
116
  E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
97
117
  within it.
98
118
  Note that the connection needs to be added to the teamspace already in order for it to be found.
99
119
  Only supported for jobs with a docker image compute environment.
100
- entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
101
- just runs the provided command in a standard shell.
102
- To use the pre-defined entrypoint of the provided image, set this to an empty string.
103
- Only applicable when submitting docker jobs.
104
120
  """
105
121
  # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
106
122
  # might need to switch to explicit cli definition
@@ -124,20 +140,25 @@ class _Run:
124
140
  interruptible: bool = False,
125
141
  image_credentials: Optional[str] = None,
126
142
  cloud_account_auth: bool = False,
143
+ entrypoint: str = "sh -c",
144
+ path_mappings: str = "",
127
145
  artifacts_local: Optional[str] = None,
128
146
  artifacts_remote: Optional[str] = None,
129
- entrypoint: str = "sh -c",
130
147
  ) -> None:
131
148
  if not name:
132
149
  from datetime import datetime
133
150
 
134
- timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
151
+ timestr = datetime.now().strftime("%b-%d-%H_%M")
135
152
  name = f"job-{timestr}"
136
153
 
137
154
  if machine is None:
138
155
  # TODO: infer from studio
139
156
  machine = "CPU"
140
- machine_enum = Machine[machine.upper()]
157
+ machine_enum: Union[str, Machine]
158
+ try:
159
+ machine_enum = Machine[machine.upper()]
160
+ except KeyError:
161
+ machine_enum = machine
141
162
 
142
163
  resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
143
164
 
@@ -145,6 +166,17 @@ class _Run:
145
166
  cloud_account = resolved_teamspace.default_cloud_account
146
167
  machine_enum = Machine(machine.upper())
147
168
 
169
+ path_mappings_dict = {}
170
+ for mapping in path_mappings.split(","):
171
+ splits = str(mapping).split(":", 1)
172
+ if len(splits) != 2:
173
+ raise RuntimeError(
174
+ "Mapping needs to be of form <CONTAINER_PATH>:<CONNECTION_NAME>[:<PATH_WITHIN_CONNECTION>], "
175
+ f"but got {mapping}"
176
+ )
177
+
178
+ path_mappings_dict[splits[0]] = splits[1]
179
+
148
180
  Job.run(
149
181
  name=name,
150
182
  machine=machine_enum,
@@ -159,9 +191,10 @@ class _Run:
159
191
  interruptible=interruptible,
160
192
  image_credentials=image_credentials,
161
193
  cloud_account_auth=cloud_account_auth,
194
+ entrypoint=entrypoint,
195
+ path_mappings=path_mappings_dict,
162
196
  artifacts_local=artifacts_local,
163
197
  artifacts_remote=artifacts_remote,
164
- entrypoint=entrypoint,
165
198
  )
166
199
 
167
200
  # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
@@ -182,20 +215,25 @@ class _Run:
182
215
  interruptible: bool = False,
183
216
  image_credentials: Optional[str] = None,
184
217
  cloud_account_auth: bool = False,
218
+ entrypoint: str = "sh -c",
219
+ path_mappings: str = "",
185
220
  artifacts_local: Optional[str] = None,
186
221
  artifacts_remote: Optional[str] = None,
187
- entrypoint: str = "sh -c",
188
222
  ) -> None:
189
223
  if name is None:
190
224
  from datetime import datetime
191
225
 
192
- timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
226
+ timestr = datetime.now().strftime("%b-%d-%H_%M")
193
227
  name = f"mmt-{timestr}"
194
228
 
195
229
  if machine is None:
196
230
  # TODO: infer from studio
197
231
  machine = "CPU"
198
- machine_enum = Machine[machine.upper()]
232
+ machine_enum: Union[str, Machine]
233
+ try:
234
+ machine_enum = Machine[machine.upper()]
235
+ except KeyError:
236
+ machine_enum = machine
199
237
 
200
238
  resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
201
239
  if cloud_account is None:
@@ -204,6 +242,17 @@ class _Run:
204
242
  if image is None:
205
243
  raise RuntimeError("Image needs to be specified to run a multi-machine job")
206
244
 
245
+ path_mappings_dict = {}
246
+ for mapping in path_mappings.split(","):
247
+ splits = str(mapping).split(":", 1)
248
+ if len(splits) != 2:
249
+ raise RuntimeError(
250
+ "Mapping needs to be of form <CONTAINER_PATH>:<CONNECTION_NAME>[:<PATH_WITHIN_CONNECTION>], "
251
+ f"but got {mapping}"
252
+ )
253
+
254
+ path_mappings_dict[splits[0]] = splits[1]
255
+
207
256
  MMT.run(
208
257
  name=name,
209
258
  num_machines=num_machines,
@@ -219,7 +268,8 @@ class _Run:
219
268
  interruptible=interruptible,
220
269
  image_credentials=image_credentials,
221
270
  cloud_account_auth=cloud_account_auth,
271
+ entrypoint=entrypoint,
272
+ path_mappings=path_mappings_dict,
222
273
  artifacts_local=artifacts_local,
223
274
  artifacts_remote=artifacts_remote,
224
- entrypoint=entrypoint,
225
275
  )
@@ -4,6 +4,7 @@ import warnings
4
4
  from pathlib import Path
5
5
  from typing import Optional, Union
6
6
 
7
+ import docker
7
8
  from rich.console import Console
8
9
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
9
10
  from rich.prompt import Confirm
@@ -89,11 +90,6 @@ class _LitServe:
89
90
  tag: str = "litserve-model",
90
91
  non_interactive: bool = False,
91
92
  ) -> None:
92
- try:
93
- import docker
94
- except ImportError:
95
- raise ImportError("docker-py is not installed. Please install it with `pip install docker`") from None
96
-
97
93
  try:
98
94
  client = docker.from_env()
99
95
  client.ping()
@@ -2,14 +2,15 @@ import concurrent.futures
2
2
  import json
3
3
  import os
4
4
  from pathlib import Path
5
- from typing import Dict, List, Optional
5
+ from typing import Dict, Generator, List, Optional
6
6
 
7
+ import rich
7
8
  from rich.console import Console
8
9
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
9
10
  from simple_term_menu import TerminalMenu
10
11
  from tqdm import tqdm
11
12
 
12
- from lightning_sdk.api.lit_container_api import LitContainerApi
13
+ from lightning_sdk.api.lit_container_api import LCRAuthFailedError, LitContainerApi
13
14
  from lightning_sdk.api.utils import _get_cloud_url
14
15
  from lightning_sdk.cli.exceptions import StudioCliError
15
16
  from lightning_sdk.cli.studios_menu import _StudiosMenu
@@ -162,21 +163,38 @@ class _Uploads(_StudiosMenu, _TeamspacesMenu):
162
163
  transient=False,
163
164
  ) as progress:
164
165
  push_task = progress.add_task("Pushing Docker image", total=None)
165
- resp = api.upload_container(container, teamspace, tag)
166
- for line in resp:
167
- if "status" in line:
168
- console.print(line["status"], style="bright_black")
169
- progress.update(push_task, description="Pushing Docker image")
170
- elif "aux" in line:
171
- console.print(line["aux"], style="bright_black")
172
- elif "error" in line:
173
- progress.stop()
174
- console.print(f"\n[red]{line}[/red]")
175
- return
176
- else:
177
- console.print(line, style="bright_black")
166
+ try:
167
+ lines = api.upload_container(container, teamspace, tag)
168
+ self._print_docker_push(lines, console, progress, push_task)
169
+ except LCRAuthFailedError:
170
+ console.print("Authenticating with Lightning Container Registry...")
171
+ if not api.authenticate():
172
+ raise StudioCliError("Failed to authenticate with Lightning Container Registry") from None
173
+ console.print("Authenticated with Lightning Container Registry", style="green")
174
+ lines = api.upload_container(container, teamspace, tag)
175
+ self._print_docker_push(lines, console, progress, push_task)
178
176
  progress.update(push_task, description="[green]Container pushed![/green]")
179
177
 
178
+ @staticmethod
179
+ def _print_docker_push(
180
+ lines: Generator, console: Console, progress: Progress, push_task: rich.progress.TaskID
181
+ ) -> None:
182
+ for line in lines:
183
+ if "status" in line:
184
+ console.print(line["status"], style="bright_black")
185
+ progress.update(push_task, description="Pushing Docker image")
186
+ elif "aux" in line:
187
+ console.print(line["aux"], style="bright_black")
188
+ elif "error" in line:
189
+ progress.stop()
190
+ console.print(f"\n[red]{line}[/red]")
191
+ return
192
+ elif "finish" in line:
193
+ console.print(f"Container available at [i]{line['url']}[/i]")
194
+ return
195
+ else:
196
+ console.print(line, style="bright_black")
197
+
180
198
  def _start_parallel_upload(
181
199
  self, executor: concurrent.futures.ThreadPoolExecutor, studio: Studio, upload_state: Dict[str, str]
182
200
  ) -> List[concurrent.futures.Future]:
lightning_sdk/helpers.py CHANGED
@@ -43,7 +43,7 @@ def _check_version_and_prompt_upgrade(curr_version: str) -> None:
43
43
  warnings.warn(
44
44
  f"A newer version of {__package_name__} is available ({new_version}). "
45
45
  f"Please consider upgrading with `pip install -U {__package_name__}`. "
46
- "Not all functionalities of the platform can be guaranteed to work with the current version.",
46
+ "Not all platform functionality can be guaranteed to work with the current version.",
47
47
  UserWarning,
48
48
  )
49
49
  return