lightning-sdk 0.1.50__py3-none-any.whl → 0.1.53__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +16 -27
  3. lightning_sdk/api/ai_hub_api.py +7 -1
  4. lightning_sdk/api/job_api.py +12 -7
  5. lightning_sdk/api/lit_container_api.py +24 -7
  6. lightning_sdk/api/mmt_api.py +12 -7
  7. lightning_sdk/api/utils.py +52 -0
  8. lightning_sdk/cli/run.py +65 -18
  9. lightning_sdk/cli/serve.py +1 -5
  10. lightning_sdk/cli/upload.py +33 -15
  11. lightning_sdk/helpers.py +1 -1
  12. lightning_sdk/job/base.py +28 -1
  13. lightning_sdk/job/job.py +27 -25
  14. lightning_sdk/job/v1.py +6 -2
  15. lightning_sdk/job/v2.py +12 -12
  16. lightning_sdk/lightning_cloud/login.py +4 -1
  17. lightning_sdk/lightning_cloud/openapi/__init__.py +17 -0
  18. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +105 -0
  19. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +417 -1
  20. lightning_sdk/lightning_cloud/openapi/api/file_system_service_api.py +105 -0
  21. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +5 -1
  22. lightning_sdk/lightning_cloud/openapi/api/lit_registry_service_api.py +113 -0
  23. lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +101 -0
  24. lightning_sdk/lightning_cloud/openapi/api/user_service_api.py +5 -1
  25. lightning_sdk/lightning_cloud/openapi/models/__init__.py +17 -0
  26. lightning_sdk/lightning_cloud/openapi/models/cluster_id_usagerestrictions_body.py +175 -0
  27. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  28. lightning_sdk/lightning_cloud/openapi/models/id_contactowner_body.py +149 -0
  29. lightning_sdk/lightning_cloud/openapi/models/litregistry_lit_repo_name_body.py +123 -0
  30. lightning_sdk/lightning_cloud/openapi/models/metricsstream_create_body.py +27 -1
  31. lightning_sdk/lightning_cloud/openapi/models/usagerestrictions_id_body.py +175 -0
  32. lightning_sdk/lightning_cloud/openapi/models/v1_assistant_model_status.py +4 -0
  33. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +104 -0
  34. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_artifact_event.py +149 -0
  35. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_artifact_event_type.py +103 -0
  36. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +81 -3
  37. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  38. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +29 -3
  39. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_usage_restriction.py +227 -0
  40. lightning_sdk/lightning_cloud/openapi/models/v1_contact_assistant_owner_reason.py +102 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_contact_assistant_owner_response.py +97 -0
  42. lightning_sdk/lightning_cloud/openapi/models/v1_delete_cluster_usage_restriction_response.py +97 -0
  43. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  44. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +53 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_filesystem_mmt.py +175 -0
  46. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
  47. lightning_sdk/lightning_cloud/openapi/models/v1_list_cluster_usage_restrictions_response.py +123 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_list_filesystem_mm_ts_response.py +123 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_metrics_stream.py +27 -1
  50. lightning_sdk/lightning_cloud/openapi/models/v1_model.py +27 -1
  51. lightning_sdk/lightning_cloud/openapi/models/v1_path_mapping.py +175 -0
  52. lightning_sdk/lightning_cloud/openapi/models/v1_post_cloud_space_artifact_events_response.py +97 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_resource_visibility.py +27 -1
  54. lightning_sdk/lightning_cloud/openapi/models/v1_update_lit_repository_response.py +97 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +128 -76
  56. lightning_sdk/lightning_cloud/utils/data_connection.py +75 -7
  57. lightning_sdk/mmt/base.py +36 -26
  58. lightning_sdk/mmt/mmt.py +28 -26
  59. lightning_sdk/mmt/v1.py +4 -1
  60. lightning_sdk/mmt/v2.py +14 -13
  61. lightning_sdk/models.py +5 -4
  62. lightning_sdk/studio.py +68 -1
  63. lightning_sdk/utils/resolve.py +7 -0
  64. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/METADATA +2 -2
  65. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/RECORD +69 -52
  66. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/LICENSE +0 -0
  67. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/WHEEL +0 -0
  68. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/entry_points.txt +0 -0
  69. {lightning_sdk-0.1.50.dist-info → lightning_sdk-0.1.53.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -29,5 +29,5 @@ __all__ = [
29
29
  "AIHub",
30
30
  ]
31
31
 
32
- __version__ = "0.1.50"
32
+ __version__ = "0.1.53"
33
33
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -1,11 +1,10 @@
1
1
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
2
2
  from urllib.parse import quote
3
3
 
4
- from lightning_sdk.api import AIHubApi, UserApi
5
- from lightning_sdk.lightning_cloud import login
6
- from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
4
+ from lightning_sdk.api import AIHubApi
5
+ from lightning_sdk.api.utils import _get_cloud_url
7
6
  from lightning_sdk.user import User
8
- from lightning_sdk.utils.resolve import _resolve_org, _resolve_teamspace
7
+ from lightning_sdk.utils.resolve import _resolve_teamspace
9
8
 
10
9
  if TYPE_CHECKING:
11
10
  from lightning_sdk import Organization, Teamspace
@@ -99,26 +98,6 @@ class AIHub:
99
98
  results.append(result)
100
99
  return results
101
100
 
102
- def _authenticate(
103
- self,
104
- teamspace: Optional[Union[str, "Teamspace"]] = None,
105
- org: Optional[Union[str, "Organization"]] = None,
106
- user: Optional[Union[str, "User"]] = None,
107
- ) -> "Teamspace":
108
- if self._auth is None:
109
- self._auth = login.Auth()
110
- try:
111
- self._auth.authenticate()
112
- user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
113
- except ConnectionError as e:
114
- raise e
115
-
116
- org = _resolve_org(org)
117
- teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user if org is None else None)
118
- if teamspace is None:
119
- raise ValueError("You need to pass a teamspace or an org for your deployment.")
120
- return teamspace
121
-
122
101
  def run(
123
102
  self,
124
103
  api_id: str,
@@ -127,6 +106,7 @@ class AIHub:
127
106
  cloud_account: Optional[str] = None,
128
107
  teamspace: Optional[Union[str, "Teamspace"]] = None,
129
108
  org: Optional[Union[str, "Organization"]] = None,
109
+ user: Optional[Union[str, "User"]] = None,
130
110
  ) -> Dict[str, Union[str, bool]]:
131
111
  """Deploy an API from the AI Hub.
132
112
 
@@ -146,7 +126,8 @@ class AIHub:
146
126
  cloud_account: The cloud account where you want to run the template, such as "lightning-public-prod".
147
127
  Defaults to None.
148
128
  teamspace: The team or group for deployment. Defaults to None.
149
- org: The organization for deployment. Defaults to None.
129
+ org: The organization for deployment. Don't pass user with this. Defaults to None.
130
+ user: The user for deployment. Don't pass org with this. Defaults to None.
150
131
 
151
132
  Returns:
152
133
  A dictionary containing the name of the deployed API,
@@ -156,7 +137,13 @@ class AIHub:
156
137
  ValueError: If a teamspace or organization is not provided.
157
138
  ConnectionError: If there is an issue with logging in.
158
139
  """
159
- teamspace = self._authenticate(teamspace, org)
140
+ if user is not None and org is not None:
141
+ raise ValueError("User and org are mutually exclusive. Please only specify the one owns the teamspace.")
142
+
143
+ teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
144
+ if teamspace is None:
145
+ raise ValueError("You need to pass a teamspace or an org for your deployment.")
146
+
160
147
  teamspace_id = teamspace.id
161
148
 
162
149
  api_arguments = api_arguments or {}
@@ -167,13 +154,15 @@ class AIHub:
167
154
  name=name,
168
155
  api_arguments=api_arguments,
169
156
  )
157
+
170
158
  url = (
171
159
  quote(
172
- f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}",
160
+ f"{_get_cloud_url}/{teamspace.owner.name}/{teamspace.name}/jobs/{deployment.name}",
173
161
  safe=":/()",
174
162
  )
175
163
  + "?app_id=deployment"
176
164
  )
165
+
177
166
  print("Deployment available at:", url)
178
167
 
179
168
  return {
lightning_sdk/api/ai_hub_api.py CHANGED
@@ -116,6 +116,12 @@ class AIHubApi:
116
116
  name = name or template.name
117
117
  template.spec_v2.endpoint.id = None
118
118
 
119
+ # These are needed to ensure templates with a max replicas of 0 will start on creation
120
+ if template.spec_v2.autoscaling.max_replicas == "0":
121
+ template.spec_v2.autoscaling.max_replicas = "1"
122
+ if not template.spec_v2.autoscaling.enabled:
123
+ template.spec_v2.autoscaling.enabled = True
124
+
119
125
  AIHubApi._set_parameters(template.spec_v2.job, template.parameter_spec.parameters, api_arguments)
120
126
  return self._client.jobs_service_create_deployment(
121
127
  project_id=project_id,
@@ -124,7 +130,7 @@ class AIHubApi:
124
130
  cluster_id=cloud_account,
125
131
  endpoint=template.spec_v2.endpoint,
126
132
  name=name,
127
- replicas=0,
133
+ replicas=1,
128
134
  spec=template.spec_v2.job,
129
135
  ),
130
136
  )
lightning_sdk/api/job_api.py CHANGED
@@ -7,10 +7,9 @@ from lightning_sdk.api.utils import (
7
7
  _create_app,
8
8
  _machine_to_compute_name,
9
9
  remove_datetime_prefix,
10
+ resolve_path_mappings,
10
11
  )
11
- from lightning_sdk.api.utils import (
12
- _get_cloud_url as _cloud_url,
13
- )
12
+ from lightning_sdk.api.utils import _get_cloud_url as _cloud_url
14
13
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
15
14
  from lightning_sdk.lightning_cloud.openapi import (
16
15
  AppinstancesIdBody,
@@ -214,9 +213,10 @@ class JobApiV2:
214
213
  env: Optional[Dict[str, str]],
215
214
  image_credentials: Optional[str],
216
215
  cloud_account_auth: bool,
217
- artifacts_local: Optional[str],
218
- artifacts_remote: Optional[str],
219
216
  entrypoint: str,
217
+ path_mappings: Optional[Dict[str, str]],
218
+ artifacts_local: Optional[str], # deprecated in favor of path_mappings
219
+ artifacts_remote: Optional[str], # deprecated in favor of path_mappings
220
220
  ) -> V1Job:
221
221
  env_vars = []
222
222
  if env is not None:
@@ -227,6 +227,12 @@ class JobApiV2:
227
227
 
228
228
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
229
229
 
230
+ path_mappings_list = resolve_path_mappings(
231
+ mappings=path_mappings or {},
232
+ artifacts_local=artifacts_local,
233
+ artifacts_remote=artifacts_remote,
234
+ )
235
+
230
236
  spec = V1JobSpec(
231
237
  cloudspace_id=studio_id or "",
232
238
  cluster_id=cloud_account or "",
@@ -239,8 +245,7 @@ class JobApiV2:
239
245
  spot=interruptible,
240
246
  image_cluster_credentials=cloud_account_auth,
241
247
  image_secret_ref=image_credentials or "",
242
- artifacts_source=artifacts_local or "",
243
- artifacts_destination=artifacts_remote or "",
248
+ path_mappings=path_mappings_list,
244
249
  )
245
250
  body = ProjectIdJobsBody(name=name, spec=spec)
246
251
 
lightning_sdk/api/lit_container_api.py CHANGED
@@ -1,23 +1,36 @@
1
1
  from typing import Generator, List
2
2
 
3
+ import docker
4
+
3
5
  from lightning_sdk.api.utils import _get_registry_url
6
+ from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
4
7
  from lightning_sdk.lightning_cloud.openapi.models import V1DeleteLitRepositoryResponse
5
8
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
6
9
  from lightning_sdk.teamspace import Teamspace
7
10
 
8
11
 
12
+ class LCRAuthFailedError(Exception):
13
+ def __init__(self) -> None:
14
+ super().__init__("Failed to authenticate with Lightning Container Registry")
15
+
16
+
9
17
  class LitContainerApi:
10
18
  def __init__(self) -> None:
11
19
  self._client = LightningClient(max_tries=3)
12
20
 
13
- import docker
14
-
15
21
  try:
16
22
  self._docker_client = docker.from_env()
17
23
  self._docker_client.ping()
18
24
  except docker.errors.DockerException as e:
19
25
  raise RuntimeError(f"Failed to connect to Docker daemon: {e!s}. Is Docker running?") from None
20
26
 
27
+ def authenticate(self) -> bool:
28
+ authed_user = self._client.auth_service_get_user()
29
+ username = authed_user.username
30
+ api_key = authed_user.api_key
31
+ resp = self._docker_client.login(username, password=api_key, registry=_get_registry_url())
32
+ return resp["Status"] == "Login Succeeded"
33
+
21
34
  def list_containers(self, project_id: str) -> List:
22
35
  project = self._client.lit_registry_service_get_lit_project_registry(project_id)
23
36
  return project.repositories
@@ -29,8 +42,6 @@ class LitContainerApi:
29
42
  raise ValueError(f"Could not delete container {container} from project {project_id}") from ex
30
43
 
31
44
  def upload_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
32
- import docker
33
-
34
45
  try:
35
46
  self._docker_client.images.get(container)
36
47
  except docker.errors.ImageNotFound:
@@ -41,11 +52,17 @@ class LitContainerApi:
41
52
  tagged = self._docker_client.api.tag(container, repository, tag)
42
53
  if not tagged:
43
54
  raise ValueError(f"Could not tag container {container} with {repository}:{tag}")
44
- return self._docker_client.api.push(repository, stream=True, decode=True)
55
+ lines = self._docker_client.api.push(repository, stream=True, decode=True)
56
+ for line in lines:
57
+ if "errorDetail" in line and "authorization failed" in line["error"]:
58
+ raise LCRAuthFailedError()
59
+ yield line
60
+ yield {
61
+ "finish": True,
62
+ "url": f"{LIGHTNING_CLOUD_URL}/{teamspace.owner.name}/{teamspace.name}/containers/{container}",
63
+ }
45
64
 
46
65
  def download_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
47
- import docker
48
-
49
66
  registry_url = _get_registry_url()
50
67
  repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
51
68
  try:
lightning_sdk/api/mmt_api.py CHANGED
@@ -7,10 +7,9 @@ from lightning_sdk.api.utils import (
7
7
  _COMPUTE_NAME_TO_MACHINE,
8
8
  _create_app,
9
9
  _machine_to_compute_name,
10
+ resolve_path_mappings,
10
11
  )
11
- from lightning_sdk.api.utils import (
12
- _get_cloud_url as _cloud_url,
13
- )
12
+ from lightning_sdk.api.utils import _get_cloud_url as _cloud_url
14
13
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
15
14
  from lightning_sdk.lightning_cloud.openapi import (
16
15
  Externalv1LightningappInstance,
@@ -85,9 +84,10 @@ class MMTApiV2:
85
84
  env: Optional[Dict[str, str]],
86
85
  image_credentials: Optional[str],
87
86
  cloud_account_auth: bool,
88
- artifacts_local: Optional[str],
89
- artifacts_remote: Optional[str],
90
87
  entrypoint: str,
88
+ path_mappings: Optional[Dict[str, str]],
89
+ artifacts_local: Optional[str], # deprecated in favor of path_mappings
90
+ artifacts_remote: Optional[str], # deprecated in favor of path_mappings
91
91
  ) -> V1MultiMachineJob:
92
92
  env_vars = []
93
93
  if env is not None:
@@ -98,6 +98,12 @@ class MMTApiV2:
98
98
 
99
99
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
100
100
 
101
+ path_mappings_list = resolve_path_mappings(
102
+ mappings=path_mappings or {},
103
+ artifacts_local=artifacts_local,
104
+ artifacts_remote=artifacts_remote,
105
+ )
106
+
101
107
  spec = V1JobSpec(
102
108
  cloudspace_id=studio_id or "",
103
109
  cluster_id=cloud_account or "",
@@ -110,8 +116,7 @@ class MMTApiV2:
110
116
  spot=interruptible,
111
117
  image_cluster_credentials=cloud_account_auth,
112
118
  image_secret_ref=image_credentials or "",
113
- artifacts_source=artifacts_local or "",
114
- artifacts_destination=artifacts_remote or "",
119
+ path_mappings=path_mappings_list,
115
120
  )
116
121
  body = ProjectIdMultimachinejobsBody(
117
122
  name=name, spec=spec, cluster_id=cloud_account or "", machines=num_machines
lightning_sdk/api/utils.py CHANGED
@@ -24,6 +24,7 @@ from lightning_sdk.lightning_cloud.openapi import (
24
24
  UploadsUploadIdBody,
25
25
  V1CompletedPart,
26
26
  V1CompleteUpload,
27
+ V1PathMapping,
27
28
  V1PresignedUrl,
28
29
  V1SignedUrl,
29
30
  V1UploadProjectArtifactPartsResponse,
@@ -614,3 +615,54 @@ def remove_datetime_prefix(text: str) -> str:
614
615
  # lines looks something like
615
616
  # '[2025-01-08T14:15:03.797142418Z] ⚡ ~ echo Hello\n[2025-01-08T14:15:03.803077717Z] Hello\n'
616
617
  return re.sub(r"^\[.*?\] ", "", text, flags=re.MULTILINE)
618
+
619
+
620
+ def resolve_path_mappings(
621
+ mappings: Dict[str, str],
622
+ artifacts_local: Optional[str],
623
+ artifacts_remote: Optional[str],
624
+ ) -> List[V1PathMapping]:
625
+ path_mappings_list = []
626
+ for k, v in mappings.items():
627
+ splitted = str(v).rsplit(":", 1)
628
+ connection_name: str
629
+ connection_path: str
630
+ if len(splitted) == 1:
631
+ connection_name = splitted[0]
632
+ connection_path = ""
633
+ else:
634
+ connection_name, connection_path = splitted
635
+
636
+ path_mappings_list.append(
637
+ V1PathMapping(
638
+ connection_name=connection_name,
639
+ connection_path=connection_path,
640
+ container_path=k,
641
+ )
642
+ )
643
+
644
+ if artifacts_remote:
645
+ splitted = str(artifacts_remote).rsplit(":", 2)
646
+ if len(splitted) not in (2, 3):
647
+ raise RuntimeError(
648
+ f"Artifacts remote need to be of format efs:connection_name[:path] but got {artifacts_remote}"
649
+ )
650
+ else:
651
+ if not artifacts_local:
652
+ raise RuntimeError("If Artifacts remote is specified, artifacts local should be specified as well")
653
+
654
+ if len(splitted) == 2:
655
+ _, connection_name = splitted
656
+ connection_path = ""
657
+ else:
658
+ _, connection_name, connection_path = splitted
659
+
660
+ path_mappings_list.append(
661
+ V1PathMapping(
662
+ connection_name=connection_name,
663
+ connection_path=connection_path,
664
+ container_path=artifacts_local,
665
+ )
666
+ )
667
+
668
+ return path_mappings_list
lightning_sdk/cli/run.py CHANGED
@@ -43,20 +43,28 @@ class _Run:
43
43
  This should be the name of the respective credentials secret created on the Lightning AI platform.
44
44
  cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
45
45
  Required if the registry is part of a cloud provider (e.g. ECR).
46
- artifacts_local: The path of inside the docker container, you want to persist images from.
46
+ entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
47
+ just runs the provided command in a standard shell.
48
+ To use the pre-defined entrypoint of the provided image, set this to an empty string.
49
+ Only applicable when submitting docker jobs.
50
+ path_mappings: Maps path inside of containers to paths inside data-connections.
51
+ Should be a comma separated list of form:
52
+ <MAPPING_1>,<MAPPING_2>,...
53
+ where each mapping is of the form
54
+ <CONTAINER_PATH_1>:<CONNECTION_NAME_1>:<PATH_WITHIN_CONNECTION_1> and
55
+ omitting the path inside the connection defaults to the connections root.
56
+ artifacts_local: Deprecated in favor of path_mappings.
57
+ The path of inside the docker container, you want to persist images from.
47
58
  CAUTION: When setting this to "/", it will effectively erase your container.
48
59
  Only supported for jobs with a docker image compute environment.
49
- artifacts_remote: The remote storage to persist your artifacts to.
60
+ artifacts_remote: Deprecated in favor of path_mappings.
61
+ The remote storage to persist your artifacts to.
50
62
  Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
51
63
  PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
52
64
  E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
53
65
  within it.
54
66
  Note that the connection needs to be added to the teamspace already in order for it to be found.
55
67
  Only supported for jobs with a docker image compute environment.
56
- entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
57
- just runs the provided command in a standard shell.
58
- To use the pre-defined entrypoint of the provided image, set this to an empty string.
59
- Only applicable when submitting docker jobs.
60
68
  """
61
69
  # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
62
70
  # might need to switch to explicit cli definition
@@ -87,20 +95,28 @@ class _Run:
87
95
  This should be the name of the respective credentials secret created on the Lightning AI platform.
88
96
  cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
89
97
  Required if the registry is part of a cloud provider (e.g. ECR).
90
- artifacts_local: The path of inside the docker container, you want to persist images from.
98
+ entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
99
+ just runs the provided command in a standard shell.
100
+ To use the pre-defined entrypoint of the provided image, set this to an empty string.
101
+ Only applicable when submitting docker jobs.
102
+ path_mappings: Maps path inside of containers to paths inside data-connections.
103
+ Should be a comma separated list of form:
104
+ <MAPPING_1>,<MAPPING_2>,...
105
+ where each mapping is of the form
106
+ <CONTAINER_PATH_1>:<CONNECTION_NAME_1>:<PATH_WITHIN_CONNECTION_1> and
107
+ omitting the path inside the connection defaults to the connections root.
108
+ artifacts_local: Deprecated in favor of path_mappings.
109
+ The path of inside the docker container, you want to persist images from.
91
110
  CAUTION: When setting this to "/", it will effectively erase your container.
92
111
  Only supported for jobs with a docker image compute environment.
93
- artifacts_remote: The remote storage to persist your artifacts to.
112
+ artifacts_remote: Deprecated in favor of path_mappings.
113
+ The remote storage to persist your artifacts to.
94
114
  Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
95
115
  PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
96
116
  E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
97
117
  within it.
98
118
  Note that the connection needs to be added to the teamspace already in order for it to be found.
99
119
  Only supported for jobs with a docker image compute environment.
100
- entrypoint: The entrypoint of your docker container. Defaults to `sh -c` which
101
- just runs the provided command in a standard shell.
102
- To use the pre-defined entrypoint of the provided image, set this to an empty string.
103
- Only applicable when submitting docker jobs.
104
120
  """
105
121
  # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
106
122
  # might need to switch to explicit cli definition
@@ -124,14 +140,15 @@ class _Run:
124
140
  interruptible: bool = False,
125
141
  image_credentials: Optional[str] = None,
126
142
  cloud_account_auth: bool = False,
143
+ entrypoint: str = "sh -c",
144
+ path_mappings: str = "",
127
145
  artifacts_local: Optional[str] = None,
128
146
  artifacts_remote: Optional[str] = None,
129
- entrypoint: str = "sh -c",
130
147
  ) -> None:
131
148
  if not name:
132
149
  from datetime import datetime
133
150
 
134
- timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
151
+ timestr = datetime.now().strftime("%b-%d-%H_%M")
135
152
  name = f"job-{timestr}"
136
153
 
137
154
  if machine is None:
@@ -149,6 +166,8 @@ class _Run:
149
166
  cloud_account = resolved_teamspace.default_cloud_account
150
167
  machine_enum = Machine(machine.upper())
151
168
 
169
+ path_mappings_dict = self._resolve_path_mapping(path_mappings=path_mappings)
170
+
152
171
  Job.run(
153
172
  name=name,
154
173
  machine=machine_enum,
@@ -163,9 +182,10 @@ class _Run:
163
182
  interruptible=interruptible,
164
183
  image_credentials=image_credentials,
165
184
  cloud_account_auth=cloud_account_auth,
185
+ entrypoint=entrypoint,
186
+ path_mappings=path_mappings_dict,
166
187
  artifacts_local=artifacts_local,
167
188
  artifacts_remote=artifacts_remote,
168
- entrypoint=entrypoint,
169
189
  )
170
190
 
171
191
  # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
@@ -186,14 +206,15 @@ class _Run:
186
206
  interruptible: bool = False,
187
207
  image_credentials: Optional[str] = None,
188
208
  cloud_account_auth: bool = False,
209
+ entrypoint: str = "sh -c",
210
+ path_mappings: str = "",
189
211
  artifacts_local: Optional[str] = None,
190
212
  artifacts_remote: Optional[str] = None,
191
- entrypoint: str = "sh -c",
192
213
  ) -> None:
193
214
  if name is None:
194
215
  from datetime import datetime
195
216
 
196
- timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
217
+ timestr = datetime.now().strftime("%b-%d-%H_%M")
197
218
  name = f"mmt-{timestr}"
198
219
 
199
220
  if machine is None:
@@ -212,6 +233,8 @@ class _Run:
212
233
  if image is None:
213
234
  raise RuntimeError("Image needs to be specified to run a multi-machine job")
214
235
 
236
+ path_mappings_dict = self._resolve_path_mapping(path_mappings=path_mappings)
237
+
215
238
  MMT.run(
216
239
  name=name,
217
240
  num_machines=num_machines,
@@ -227,7 +250,31 @@ class _Run:
227
250
  interruptible=interruptible,
228
251
  image_credentials=image_credentials,
229
252
  cloud_account_auth=cloud_account_auth,
253
+ entrypoint=entrypoint,
254
+ path_mappings=path_mappings_dict,
230
255
  artifacts_local=artifacts_local,
231
256
  artifacts_remote=artifacts_remote,
232
- entrypoint=entrypoint,
233
257
  )
258
+
259
+ @staticmethod
260
+ def _resolve_path_mapping(path_mappings: str) -> Dict[str, str]:
261
+ path_mappings = path_mappings.strip()
262
+
263
+ if not path_mappings:
264
+ return {}
265
+
266
+ path_mappings_dict = {}
267
+ for mapping in path_mappings.split(","):
268
+ if not mapping.strip():
269
+ continue
270
+
271
+ splits = str(mapping).split(":", 1)
272
+ if len(splits) != 2:
273
+ raise RuntimeError(
274
+ "Mapping needs to be of form <CONTAINER_PATH>:<CONNECTION_NAME>[:<PATH_WITHIN_CONNECTION>], "
275
+ f"but got {mapping}"
276
+ )
277
+
278
+ path_mappings_dict[splits[0].strip()] = splits[1].strip()
279
+
280
+ return path_mappings_dict
lightning_sdk/cli/serve.py CHANGED
@@ -4,6 +4,7 @@ import warnings
4
4
  from pathlib import Path
5
5
  from typing import Optional, Union
6
6
 
7
+ import docker
7
8
  from rich.console import Console
8
9
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
9
10
  from rich.prompt import Confirm
@@ -89,11 +90,6 @@ class _LitServe:
89
90
  tag: str = "litserve-model",
90
91
  non_interactive: bool = False,
91
92
  ) -> None:
92
- try:
93
- import docker
94
- except ImportError:
95
- raise ImportError("docker-py is not installed. Please install it with `pip install docker`") from None
96
-
97
93
  try:
98
94
  client = docker.from_env()
99
95
  client.ping()
lightning_sdk/cli/upload.py CHANGED
@@ -2,14 +2,15 @@ import concurrent.futures
2
2
  import json
3
3
  import os
4
4
  from pathlib import Path
5
- from typing import Dict, List, Optional
5
+ from typing import Dict, Generator, List, Optional
6
6
 
7
+ import rich
7
8
  from rich.console import Console
8
9
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
9
10
  from simple_term_menu import TerminalMenu
10
11
  from tqdm import tqdm
11
12
 
12
- from lightning_sdk.api.lit_container_api import LitContainerApi
13
+ from lightning_sdk.api.lit_container_api import LCRAuthFailedError, LitContainerApi
13
14
  from lightning_sdk.api.utils import _get_cloud_url
14
15
  from lightning_sdk.cli.exceptions import StudioCliError
15
16
  from lightning_sdk.cli.studios_menu import _StudiosMenu
@@ -162,21 +163,38 @@ class _Uploads(_StudiosMenu, _TeamspacesMenu):
162
163
  transient=False,
163
164
  ) as progress:
164
165
  push_task = progress.add_task("Pushing Docker image", total=None)
165
- resp = api.upload_container(container, teamspace, tag)
166
- for line in resp:
167
- if "status" in line:
168
- console.print(line["status"], style="bright_black")
169
- progress.update(push_task, description="Pushing Docker image")
170
- elif "aux" in line:
171
- console.print(line["aux"], style="bright_black")
172
- elif "error" in line:
173
- progress.stop()
174
- console.print(f"\n[red]{line}[/red]")
175
- return
176
- else:
177
- console.print(line, style="bright_black")
166
+ try:
167
+ lines = api.upload_container(container, teamspace, tag)
168
+ self._print_docker_push(lines, console, progress, push_task)
169
+ except LCRAuthFailedError:
170
+ console.print("Authenticating with Lightning Container Registry...")
171
+ if not api.authenticate():
172
+ raise StudioCliError("Failed to authenticate with Lightning Container Registry") from None
173
+ console.print("Authenticated with Lightning Container Registry", style="green")
174
+ lines = api.upload_container(container, teamspace, tag)
175
+ self._print_docker_push(lines, console, progress, push_task)
178
176
  progress.update(push_task, description="[green]Container pushed![/green]")
179
177
 
178
+ @staticmethod
179
+ def _print_docker_push(
180
+ lines: Generator, console: Console, progress: Progress, push_task: rich.progress.TaskID
181
+ ) -> None:
182
+ for line in lines:
183
+ if "status" in line:
184
+ console.print(line["status"], style="bright_black")
185
+ progress.update(push_task, description="Pushing Docker image")
186
+ elif "aux" in line:
187
+ console.print(line["aux"], style="bright_black")
188
+ elif "error" in line:
189
+ progress.stop()
190
+ console.print(f"\n[red]{line}[/red]")
191
+ return
192
+ elif "finish" in line:
193
+ console.print(f"Container available at [i]{line['url']}[/i]")
194
+ return
195
+ else:
196
+ console.print(line, style="bright_black")
197
+
180
198
  def _start_parallel_upload(
181
199
  self, executor: concurrent.futures.ThreadPoolExecutor, studio: Studio, upload_state: Dict[str, str]
182
200
  ) -> List[concurrent.futures.Future]:
lightning_sdk/helpers.py CHANGED
@@ -43,7 +43,7 @@ def _check_version_and_prompt_upgrade(curr_version: str) -> None:
43
43
  warnings.warn(
44
44
  f"A newer version of {__package_name__} is available ({new_version}). "
45
45
  f"Please consider upgrading with `pip install -U {__package_name__}`. "
46
- "Not all functionalities of the platform can be guaranteed to work with the current version.",
46
+ "Not all platform functionality can be guaranteed to work with the current version.",
47
47
  UserWarning,
48
48
  )
49
49
  return