lightning-sdk 0.1.37__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +21 -23
  3. lightning_sdk/api/ai_hub_api.py +29 -4
  4. lightning_sdk/api/deployment_api.py +0 -2
  5. lightning_sdk/api/job_api.py +10 -2
  6. lightning_sdk/api/teamspace_api.py +22 -16
  7. lightning_sdk/api/utils.py +25 -3
  8. lightning_sdk/cli/ai_hub.py +1 -1
  9. lightning_sdk/cli/download.py +3 -5
  10. lightning_sdk/cli/run.py +24 -0
  11. lightning_sdk/cli/upload.py +3 -10
  12. lightning_sdk/job/base.py +35 -0
  13. lightning_sdk/job/job.py +18 -1
  14. lightning_sdk/job/v1.py +10 -1
  15. lightning_sdk/job/v2.py +16 -0
  16. lightning_sdk/lightning_cloud/openapi/__init__.py +13 -2
  17. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +5 -1
  18. lightning_sdk/lightning_cloud/openapi/api/data_connection_service_api.py +6 -1
  19. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +680 -62
  20. lightning_sdk/lightning_cloud/openapi/api/models_store_api.py +118 -1
  21. lightning_sdk/lightning_cloud/openapi/models/__init__.py +13 -2
  22. lightning_sdk/lightning_cloud/openapi/models/create.py +6 -32
  23. lightning_sdk/lightning_cloud/openapi/models/deploymenttemplates_id_body.py +32 -6
  24. lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +27 -1
  25. lightning_sdk/lightning_cloud/openapi/models/id_start_body.py +29 -3
  26. lightning_sdk/lightning_cloud/openapi/models/multimachinejobs_id_body.py +123 -0
  27. lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +53 -1
  28. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +53 -1
  29. lightning_sdk/lightning_cloud/openapi/models/project_id_multimachinejobs_body.py +201 -0
  30. lightning_sdk/lightning_cloud/openapi/models/update.py +6 -32
  31. lightning_sdk/lightning_cloud/openapi/models/v1_api_pricing_spec.py +149 -0
  32. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/v1_create_deployment_template_request.py +32 -6
  34. lightning_sdk/lightning_cloud/openapi/models/v1_data_connection.py +6 -32
  35. lightning_sdk/lightning_cloud/openapi/models/v1_data_path.py +29 -3
  36. lightning_sdk/lightning_cloud/openapi/models/v1_delete_multi_machine_job_response.py +97 -0
  37. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_metrics.py +43 -17
  38. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_performance.py +305 -0
  39. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template.py +32 -6
  40. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_parameter.py +27 -1
  41. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_parameter_type.py +1 -0
  42. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_summary.py +27 -1
  43. lightning_sdk/lightning_cloud/openapi/models/{v1_efs_data_connection.py → v1_efs_config.py} +22 -22
  44. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +27 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +53 -1
  46. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +53 -53
  47. lightning_sdk/lightning_cloud/openapi/models/v1_lambda_labs_direct_v1.py +125 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_list_multi_machine_jobs_response.py +123 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_machines_selector.py +149 -0
  50. lightning_sdk/lightning_cloud/openapi/models/v1_message.py +6 -6
  51. lightning_sdk/lightning_cloud/openapi/models/v1_message_content.py +6 -6
  52. lightning_sdk/lightning_cloud/openapi/models/v1_message_content_type.py +103 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_metrics_stream.py +53 -1
  54. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +409 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_state.py +106 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_status.py +279 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +2 -0
  58. lightning_sdk/lightning_cloud/openapi/models/v1_system_info.py +27 -1
  59. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +53 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
  61. lightning_sdk/lightning_cloud/openapi/models/v1_validate_data_connection_response.py +6 -32
  62. lightning_sdk/lightning_cloud/openapi/models/validate.py +6 -32
  63. lightning_sdk/models.py +132 -0
  64. lightning_sdk/teamspace.py +8 -2
  65. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/METADATA +1 -1
  66. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/RECORD +70 -59
  67. lightning_sdk/cli/models.py +0 -68
  68. lightning_sdk/lightning_cloud/openapi/models/v1_efs_folder_data_connection.py +0 -201
  69. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/LICENSE +0 -0
  70. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/WHEEL +0 -0
  71. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/entry_points.txt +0 -0
  72. {lightning_sdk-0.1.37.dist-info → lightning_sdk-0.1.39.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -27,5 +27,5 @@ __all__ = [
     "AIHub",
 ]
 
-__version__ = "0.1.37"
+__version__ = "0.1.39"
 _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -47,17 +47,7 @@ class AIHub:
         including its name, description, creation and update timestamps,
         parameters, tags, job specifications, and autoscaling settings.
         """
-        template = self._api.api_info(api_id)
-
-        api_arguments = [
-            {
-                "name": param.name,
-                "short_description": param.short_description,
-                "required": param.required,
-                "default": param.input.default_value,
-            }
-            for param in template.parameter_spec.parameters
-        ]
+        template, api_arguments = self._api.api_info(api_id)
 
         return {
             "name": template.name,
@@ -129,33 +119,33 @@ class AIHub:
             raise ValueError("You need to pass a teamspace or an org for your deployment.")
         return teamspace
 
-    def deploy(
+    def run(
         self,
         api_id: str,
-        cluster: Optional[str] = None,
+        api_arguments: Optional[Dict[str, Any]] = None,
         name: Optional[str] = None,
+        cluster: Optional[str] = None,
         teamspace: Optional[Union[str, "Teamspace"]] = None,
         org: Optional[Union[str, "Organization"]] = None,
-        api_arguments: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Union[str, bool]]:
         """Deploy an API from the AI Hub.
 
         Example:
             from lightning_sdk import AIHub
             hub = AIHub()
-            deployment = hub.deploy("temp_xxxx")
+            deployment = hub.run("temp_xxxx")
 
             # Using API arguments
-            api_arugments = {"model": "unitary/toxic-bert", "batch_size" 10, "token": "lit_xxxx"}
-            deployment = hub.deploy("temp_xxxx", api_arugments=api_arugments)
+            api_arugments = {"model": "unitary/toxic-bert", "batch_size": 10, "token": "lit_xxxx"}
+            deployment = hub.run("temp_xxxx", api_arugments=api_arugments)
 
         Args:
-            api_id: The ID of the API you want to deploy.
-            cluster: The cluster where you want to deploy the API, such as "lightning-public-prod". Defaults to None.
+            api_id: The ID of the AIHub template you want to run.
+            api_arguments: Additional API argument, such as model name, or batch size.
            name: Name for the deployed API. Defaults to None.
+            cluster: The cluster where you want to run the template, such as "lightning-public-prod". Defaults to None.
            teamspace: The team or group for deployment. Defaults to None.
            org: The organization for deployment. Defaults to None.
-            api_arguments: Additional API argument, such as model name, or batch size.
 
         Returns:
            A dictionary containing the name of the deployed API,
@@ -169,14 +159,22 @@ class AIHub:
         teamspace_id = teamspace.id
 
         api_arguments = api_arguments or {}
-        deployment = self._api.deploy_api(
+        deployment = self._api.run_api(
             template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, api_arguments=api_arguments
         )
-        url = quote(f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}", safe=":/()")
+        url = (
+            quote(
+                f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}",
+                safe=":/()",
+            )
+            + "?app_id=deployment"
+        )
         print("Deployment available at:", url)
+
         return {
             "id": deployment.id,
             "name": deployment.name,
-            "base_url": deployment.status.urls[0],
+            "deployment_url": url,
+            "api_endpoint": deployment.status.urls[0],
             "interruptible": deployment.spec.spot,
         }
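
The AIHub entry point above renames deploy to run, moves api_arguments right after api_id, and returns both the dashboard URL and the API endpoint. A minimal usage sketch based on the signature in this diff; the template ID, argument values, and teamspace name are placeholders:

    from lightning_sdk import AIHub

    hub = AIHub()
    # Run an AI Hub template; api_arguments are forwarded to the template's parameters.
    result = hub.run(
        "temp_xxxx",
        api_arguments={"model": "unitary/toxic-bert", "batch_size": 10},
        name="my-deployment",
        teamspace="my-teamspace",  # alternatively pass org=
    )
    # The returned dict now carries the dashboard link and the API endpoint separately.
    print(result["deployment_url"], result["api_endpoint"])
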
lightning_sdk/api/ai_hub_api.py CHANGED
@@ -1,5 +1,5 @@
 import traceback
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import backoff
 
@@ -22,15 +22,40 @@ class AIHubApi:
     def __init__(self) -> None:
         self._client = LightningClient(max_tries=3)
 
-    def api_info(self, api_id: str) -> "V1DeploymentTemplate":
+    def api_info(self, api_id: str) -> Tuple[V1DeploymentTemplate, List[Dict[str, str]]]:
         try:
-            return self._client.deployment_templates_service_get_deployment_template(api_id)
+            template = self._client.deployment_templates_service_get_deployment_template(api_id)
         except Exception as e:
             stack_trace = traceback.format_exc()
             if "record not found" in stack_trace:
                 raise ValueError(f"api_id={api_id} not found.") from None
             raise e
 
+        api_arguments = []
+        for param in template.parameter_spec.parameters:
+            default = None
+            if param.type == V1DeploymentTemplateParameterType.INPUT and param.input:
+                default = param.input.default_value
+            if param.type == V1DeploymentTemplateParameterType.SELECT and param.select:
+                default = param.select.options[0]
+            if param.type == V1DeploymentTemplateParameterType.CHECKBOX and param.checkbox:
+                default = (
+                    (param.checkbox.true_value or "True")
+                    if param.checkbox.is_checked
+                    else (param.checkbox.false_value or "False")
+                )
+
+            api_arguments.append(
+                {
+                    "name": param.name,
+                    "short_description": param.short_description,
+                    "required": param.required,
+                    "type": param.type,
+                    "default": default,
+                }
+            )
+        return template, api_arguments
+
     @backoff.on_predicate(backoff.expo, lambda x: not x, max_tries=5)
     def list_apis(self, search_query: str) -> List[V1DeploymentTemplateGalleryResponse]:
         kwargs = {"show_globally_visible": True}
@@ -84,7 +109,7 @@ class AIHubApi:
 
         return job
 
-    def deploy_api(
+    def run_api(
         self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], api_arguments: Dict[str, str]
     ) -> V1Deployment:
         template = self._client.deployment_templates_service_get_deployment_template(template_id)
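
As shown above, AIHubApi.api_info now returns a (template, api_arguments) tuple instead of the bare template, resolving a default per parameter type (input, select, checkbox). A hedged sketch of how a caller consumes it; the template ID is a placeholder:

    from lightning_sdk.api.ai_hub_api import AIHubApi

    template, api_arguments = AIHubApi().api_info("temp_xxxx")
    for arg in api_arguments:
        # Each entry is a plain dict with name, short_description, required, type, and default.
        print(arg["name"], arg["type"], arg["default"])
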
lightning_sdk/api/deployment_api.py CHANGED
@@ -548,8 +548,6 @@ def to_spec(
        spot=spot,
        instance_name=_MACHINE_TO_COMPUTE_NAME[machine],
        readiness_probe=to_health_check(health_check),
-        skip_data_connections_setup=True,
-        skip_filesystem_setup=True,
    )
 
 
lightning_sdk/api/job_api.py CHANGED
@@ -141,13 +141,17 @@ class JobApiV2:
         self,
         name: str,
         command: Optional[str],
-        cluster_id: str,
+        cluster_id: Optional[str],
         teamspace_id: str,
         studio_id: Optional[str],
         image: Optional[str],
         machine: Machine,
         interruptible: bool,
         env: Optional[Dict[str, str]],
+        image_credentials: Optional[str],
+        cluster_auth: bool,
+        artifacts_local: Optional[str],
+        artifacts_remote: Optional[str],
     ) -> V1Job:
         env_vars = []
         if env is not None:
@@ -160,13 +164,17 @@ class JobApiV2:
 
         spec = V1JobSpec(
             cloudspace_id=studio_id or "",
-            cluster_id=cluster_id,
+            cluster_id=cluster_id or "",
             command=command or "",
             env=env_vars,
             image=image or "",
             instance_name=instance_name,
             run_id=run_id,
             spot=interruptible,
+            image_cluster_credentials=cluster_auth,
+            image_secret_ref=image_credentials or "",
+            artifacts_source=artifacts_local or "",
+            artifacts_destination=artifacts_remote or "",
         )
         body = ProjectIdJobsBody(name=name, spec=spec)
 
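
The job submission above maps the new SDK arguments onto V1JobSpec fields: cluster_auth becomes image_cluster_credentials, image_credentials becomes image_secret_ref, artifacts_local becomes artifacts_source, and artifacts_remote becomes artifacts_destination, with empty strings as the "unset" value. A hedged sketch of that mapping in isolation; the import location mirrors the generated client used elsewhere in this diff and all values are placeholders:

    from lightning_sdk.lightning_cloud.openapi import V1JobSpec

    spec = V1JobSpec(
        cluster_id="my-cluster",
        command="python train.py",
        image="ghcr.io/example/train:latest",
        spot=False,
        image_cluster_credentials=False,             # from cluster_auth
        image_secret_ref="my-registry-secret",       # from image_credentials
        artifacts_source="/outputs",                 # from artifacts_local
        artifacts_destination="efs:data:some-path",  # from artifacts_remote
    )
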
lightning_sdk/api/teamspace_api.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional
 
-from lightning_sdk.api.utils import _download_model_files, _DummyBody, _ModelFileUploader
+from lightning_sdk.api.utils import _download_model_files, _DummyBody, _get_model_version, _ModelFileUploader
 from lightning_sdk.lightning_cloud.login import Auth
 from lightning_sdk.lightning_cloud.openapi import (
     ModelIdVersionsBody,
@@ -159,6 +159,16 @@ class TeamspaceApi:
 
         return self._client.assistants_service_create_assistant(body=body, project_id=teamspace_id)
 
+    # lazy property which is only created when needed
+    @property
+    def models(self) -> ModelsStoreApi:
+        if not self._models:
+            self._models = ModelsStoreApi(self._client.api_client)
+        return self._models
+
+    def get_model_version(self, name: str, version: str, teamspace_id: str) -> V1ModelVersionArchive:
+        return _get_model_version(client=self._client, name=name, version=version, teamspace_id=teamspace_id)
+
     def create_model(
         self,
         name: str,
@@ -167,17 +177,15 @@ class TeamspaceApi:
         teamspace_id: str,
         cluster_id: str,
     ) -> V1ModelVersionArchive:
-        if not self._models:
-            self._models = ModelsStoreApi(self._client.api_client)
         # ask if such model already exists by listing models with specific name
-        models = self._models.models_store_list_models(project_id=teamspace_id, name=name).models
+        models = self.models.models_store_list_models(project_id=teamspace_id, name=name).models
         if len(models) == 0:
-            return self._models.models_store_create_model(
+            return self.models.models_store_create_model(
                 body=ProjectIdModelsBody(cluster_id=cluster_id, metadata=metadata, name=name, private=private),
                 project_id=teamspace_id,
             )
         assert len(models) == 1, "Multiple models with the same name found"
-        return self._models.models_store_create_model_version(
+        return self.models.models_store_create_model_version(
             body=ModelIdVersionsBody(cluster_id=cluster_id),
             project_id=teamspace_id,
             model_id=models[0].id,
@@ -185,18 +193,16 @@
 
     def delete_model(self, name: str, version: Optional[str], teamspace_id: str) -> None:
         """Delete a model or a version from the model store."""
-        if not self._models:
-            self._models = ModelsStoreApi(self._client.api_client)
-        models = self._models.models_store_list_models(project_id=teamspace_id, name=name).models
+        models = self.models.models_store_list_models(project_id=teamspace_id, name=name).models
         assert len(models) == 1, "Multiple models with the same name found"
         model_id = models[0].id
         # decide if delete only version of whole model
         if version:
             if version == "latest":
                 version = models[0].latest_version
-            self._models.models_store_delete_model_version(project_id=teamspace_id, model_id=model_id, version=version)
+            self.models.models_store_delete_model_version(project_id=teamspace_id, model_id=model_id, version=version)
         else:
-            self._models.models_store_delete_model(project_id=teamspace_id, model_id=model_id)
+            self.models.models_store_delete_model(project_id=teamspace_id, model_id=model_id)
 
     def upload_model_file(
         self,
@@ -242,9 +248,7 @@ class TeamspaceApi:
         )
 
     def complete_model_upload(self, model_id: str, version: str, teamspace_id: str) -> None:
-        if not self._models:
-            self._models = ModelsStoreApi(self._client.api_client)
-        self._models.models_store_complete_model_upload(
+        self.models.models_store_complete_model_upload(
            body=_DummyBody(),
            project_id=teamspace_id,
            model_id=model_id,
@@ -256,12 +260,14 @@ class TeamspaceApi:
        name: str,
        version: str,
        download_dir: Path,
-        teamspace_id: str,
+        teamspace_name: str,
+        teamspace_owner_name: str,
        progress_bar: bool = True,
    ) -> List[str]:
        return _download_model_files(
            client=self._client,
-            teamspace_id=teamspace_id,
+            teamspace_name=teamspace_name,
+            teamspace_owner_name=teamspace_owner_name,
            name=name,
            version=version,
            download_dir=download_dir,
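
TeamspaceApi now builds its ModelsStoreApi client lazily through the models property and exposes a get_model_version wrapper around the new helper in api/utils.py. A hedged usage sketch; the no-argument constructor is an assumption and the teamspace ID is a placeholder:

    from lightning_sdk.api.teamspace_api import TeamspaceApi

    api = TeamspaceApi()  # assumed to need no arguments, like the other internal API wrappers in this diff
    # "latest" resolves to the model's newest version (see _get_model_version below).
    version = api.get_model_version(name="my-model", version="latest", teamspace_id="<teamspace-id>")
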
lightning_sdk/api/utils.py CHANGED
@@ -28,6 +28,7 @@ from lightning_sdk.lightning_cloud.openapi import (
     V1UploadProjectArtifactResponse,
     VersionUploadsBody,
 )
+from lightning_sdk.lightning_cloud.openapi.models.v1_model_version_archive import V1ModelVersionArchive
 
 try:
     from lightning_sdk.lightning_cloud.openapi import AppsIdBody1 as AppsIdBody
@@ -492,9 +493,28 @@ class _FileDownloader:
         os.rename(tmp_filename, self.local_path)
 
 
+def _get_model_version(client: LightningClient, teamspace_id: str, name: str, version: str) -> V1ModelVersionArchive:
+    api = ModelsStoreApi(client.api_client)
+    models = api.models_store_list_models(project_id=teamspace_id, name=name).models
+    if not models:
+        raise ValueError(f"Model `{name}` does not exist")
+    elif len(models) > 1:
+        raise ValueError("Multiple models with the same name found")
+    if version == "latest":
+        return models[0].latest_version
+    versions = api.models_store_list_model_versions(project_id=teamspace_id, model_id=models[0].id).versions
+    if not versions:
+        raise ValueError(f"Model `{name}` does not have any versions")
+    for ver in versions:
+        if ver.version == version:
+            return ver
+    raise ValueError(f"Model `{name}` does not have version `{version}`")
+
+
 def _download_model_files(
     client: LightningClient,
-    teamspace_id: str,
+    teamspace_name: str,
+    teamspace_owner_name: str,
     name: str,
     version: str,
     download_dir: Path,
@@ -502,7 +522,9 @@ def _download_model_files(
     num_workers: int = 20,
 ) -> List[str]:
     api = ModelsStoreApi(client.api_client)
-    response = api.models_store_get_model_files(project_id=teamspace_id, name=name, version=version)
+    response = api.models_store_get_model_files(
+        project_name=teamspace_name, project_owner_name=teamspace_owner_name, name=name, version=version
+    )
 
     pbar = None
     if progress_bar:
@@ -522,7 +544,7 @@ def _download_model_files(
            client=client,
            model_id=response.model_id,
            version=response.version,
-            teamspace_id=teamspace_id,
+            teamspace_id=response.project_id,
            remote_path=filepath,
            file_path=str(local_file),
            num_workers=num_workers,
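
The new _get_model_version helper above resolves a model version inside a teamspace: "latest" short-circuits to the model's latest_version, any other value must match one of the listed versions, and missing models or versions raise ValueError. A dependency-free illustration of that branching, using made-up stand-ins for the API responses:

    from types import SimpleNamespace

    model = SimpleNamespace(latest_version="1.2.0")
    versions = [SimpleNamespace(version="1.0.0"), SimpleNamespace(version="1.2.0")]

    def resolve(requested: str):
        # Mirrors the helper's logic: "latest" returns latest_version, otherwise scan the listed versions.
        if requested == "latest":
            return model.latest_version
        for ver in versions:
            if ver.version == requested:
                return ver
        raise ValueError(f"Model does not have version `{requested}`")

    print(resolve("latest"))  # -> 1.2.0
    print(resolve("1.0.0"))   # -> namespace(version='1.0.0')
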
lightning_sdk/cli/ai_hub.py CHANGED
@@ -46,4 +46,4 @@ class _AIHub(_StudiosMenu):
             teamspace: Teamspace to deploy the API to. Defaults to user's default teamspace.
             org: Organization to deploy the API to. Defaults to user's default organization.
         """
-        return self._hub.deploy(api_id, cluster=cluster, name=name, teamspace=teamspace, org=org)
+        return self._hub.run(api_id, cluster=cluster, name=name, teamspace=teamspace, org=org)
lightning_sdk/cli/download.py CHANGED
@@ -4,8 +4,8 @@ from pathlib import Path
 from typing import Optional
 
 from lightning_sdk.cli.exceptions import StudioCliError
-from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
 from lightning_sdk.cli.studios_menu import _StudiosMenu
+from lightning_sdk.models import download_model
 from lightning_sdk.studio import Studio
 from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
 
@@ -21,10 +21,8 @@ class _Downloads(_StudiosMenu):
                 This should have the format <ORGANIZATION-NAME>/<TEAMSPACE-NAME>/<MODEL-NAME>.
             download_dir: The directory where the Model should be downloaded.
         """
-        org_name, teamspace_name, model_name = _parse_model_name(name)
-        teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
-        teamspace.download_model(
-            name=model_name,
+        download_model(
+            name=name,
             download_dir=download_dir,
             progress_bar=True,
         )
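
The CLI download path above now delegates to the new top-level lightning_sdk.models.download_model helper (added as lightning_sdk/models.py in this release) and passes the fully qualified model name straight through. A hedged usage sketch; the org/teamspace/model name and target directory are placeholders:

    from lightning_sdk.models import download_model

    # The name keeps the <ORGANIZATION-NAME>/<TEAMSPACE-NAME>/<MODEL-NAME> format from the docstring above.
    download_model(
        name="my-org/my-teamspace/my-model",
        download_dir="./my-model",
        progress_bar=True,
    )
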
lightning_sdk/cli/run.py CHANGED
@@ -36,7 +36,23 @@ class _Run:
                 If not provided will fall back to the teamspaces default cluster.
             env: Environment variables to set inside the job.
             interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
+            image_credentials: The credentials used to pull the image. Required if the image is private.
+                This should be the name of the respective credentials secret created on the Lightning AI platform.
+            cluster_auth: Whether to authenticate with the cluster to pull the image.
+                Required if the registry is part of a cluster provider (e.g. ECR).
+            artifacts_local: The path of inside the docker container, you want to persist images from.
+                CAUTION: When setting this to "/", it will effectively erase your container.
+                Only supported for jobs with a docker image compute environment.
+            artifacts_remote: The remote storage to persist your artifacts to.
+                Should be of format <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>.
+                PATH_WITHIN_CONNECTION hereby is a path relative to the connection's root.
+                E.g. efs:data:some-path would result in an EFS connection named `data` and to the path `some-path`
+                within it.
+                Note that the connection needs to be added to the teamspace already in order for it to be found.
+                Only supported for jobs with a docker image compute environment.
         """
+        # TODO: the docstrings from artifacts_local and artifacts_remote don't show up completely,
+        # might need to switch to explicit cli definition
         self.job.__func__.__doc__ = docstr
 
         # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
@@ -55,6 +71,10 @@ class _Run:
         cluster: Optional[str] = None,
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> None:
         machine_enum = Machine(machine.upper())
         Job.run(
@@ -69,4 +89,8 @@ class _Run:
             cluster=cluster,
             env=env,
             interruptible=interruptible,
+            image_credentials=image_credentials,
+            cluster_auth=cluster_auth,
+            artifacts_local=artifacts_local,
+            artifacts_remote=artifacts_remote,
         )
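
The run command gains four pass-through options (image_credentials, cluster_auth, artifacts_local, artifacts_remote) that only apply to image-based jobs. A hedged sketch of the equivalent SDK call; the image, secret, connection, and machine values are placeholders:

    from lightning_sdk import Job, Machine

    job = Job.run(
        name="docker-job",
        command="python train.py",
        machine=Machine.CPU,
        image="ghcr.io/example/train:latest",     # custom image instead of a Studio env
        image_credentials="my-registry-secret",   # platform secret for a private registry
        cluster_auth=False,                       # set True when pulling through the cluster provider (e.g. ECR)
        artifacts_local="/outputs",               # path inside the container to persist
        artifacts_remote="efs:data:some-path",    # <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>
    )
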
lightning_sdk/cli/upload.py CHANGED
@@ -9,8 +9,8 @@ from tqdm import tqdm
 
 from lightning_sdk.api.utils import _get_cloud_url
 from lightning_sdk.cli.exceptions import StudioCliError
-from lightning_sdk.cli.models import _get_teamspace, _parse_model_name
 from lightning_sdk.cli.studios_menu import _StudiosMenu
+from lightning_sdk.models import upload_model
 from lightning_sdk.studio import Studio
 from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
 
@@ -20,7 +20,7 @@ class _Uploads(_StudiosMenu):
 
     _studio_upload_status_path = "~/.lightning/studios/uploads"
 
-    def model(self, name: str, path: Optional[str] = None, cloud_account: Optional[str] = None) -> None:
+    def model(self, name: str, path: str = ".", cloud_account: Optional[str] = None) -> None:
         """Upload a Model.
 
         Args:
@@ -29,14 +29,7 @@ class _Uploads(_StudiosMenu):
             path: The path to the file or directory you want to upload. Defaults to the current directory.
             cloud_account: The name of the cloud account to store the Model in.
         """
-        org_name, teamspace_name, model_name = _parse_model_name(name)
-        teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
-        teamspace.upload_model(
-            path=path or ".",
-            name=model_name,
-            progress_bar=True,
-            cluster_id=cloud_account,
-        )
+        upload_model(name, path, cloud_account=cloud_account)
 
     def _resolve_studio(self, studio: Optional[str]) -> Studio:
         user = _get_authed_user()
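
Upload mirrors the download change: the CLI now forwards to lightning_sdk.models.upload_model, and path defaults to the current directory instead of None. A hedged usage sketch; the model name and path are placeholders:

    from lightning_sdk.models import upload_model

    # Positional name and path match the CLI call above; cloud_account stays optional.
    upload_model("my-org/my-teamspace/my-model", "./checkpoints", cloud_account=None)
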
lightning_sdk/job/base.py CHANGED
@@ -50,6 +50,10 @@ class _BaseJob(ABC):
         cluster: Optional[str] = None,
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> "_BaseJob":
         from lightning_sdk.studio import Studio
 
@@ -80,12 +84,35 @@ class _BaseJob(ABC):
                     "Studio cluster does not match provided cluster. "
                     "Can only run jobs with Studio envs in the same cluster."
                 )
+
+            if image_credentials is not None:
+                raise ValueError("image_credentials is only supported when using a custom image")
+
+            if cluster_auth:
+                raise ValueError("cluster_auth is only supported when using a custom image")
+
+            if artifacts_local is not None or artifacts_remote is not None:
+                raise ValueError(
+                    "Specifying artifacts persistence is supported for docker images only. "
+                    "Other jobs will automatically persist artifacts to the teamspace distributed filesystem."
+                )
+
         else:
             if studio is not None:
                 raise RuntimeError(
                     "image and studio are mutually exclusive as both define the environment to run the job in"
                 )
 
+            # they either need to specified both or none of them
+            if bool(artifacts_local) != bool(artifacts_remote):
+                raise ValueError("Artifact persistence requires both artifacts_local and artifacts_remote to be set")
+
+            if artifacts_remote and len(artifacts_remote.split(":")) != 3:
+                raise ValueError(
+                    "Artifact persistence requires exactly three arguments separated by colon of kind "
+                    f"<CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>, got {artifacts_local}"
+                )
+
         inst = cls(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=False)
         inst._submit(
             machine=machine,
@@ -95,6 +122,10 @@ class _BaseJob(ABC):
             image=image,
             env=env,
             interruptible=interruptible,
+            image_credentials=image_credentials,
+            cluster_auth=cluster_auth,
+            artifacts_local=artifacts_local,
+            artifacts_remote=artifacts_remote,
         )
         return inst
 
@@ -108,6 +139,10 @@ class _BaseJob(ABC):
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
         cluster: Optional[str] = None,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> None:
         """Submits a job and updates the internal _job attribute as well as the _name attribute."""
 
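
_BaseJob.run now validates the new arguments up front: with a Studio environment none of them may be set, and with a custom image artifacts_local and artifacts_remote must be given together, with artifacts_remote splitting into exactly three colon-separated parts. A small illustration of that format check, reusing the example value from the run.py docstring:

    remote = "efs:data:some-path"
    assert len(remote.split(":")) == 3  # <CONNECTION_TYPE>:<CONNECTION_NAME>:<PATH_WITHIN_CONNECTION>

    bad = "efs:data"
    assert len(bad.split(":")) != 3     # _BaseJob.run raises ValueError for this shape
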
lightning_sdk/job/job.py CHANGED
@@ -58,6 +58,10 @@ class Job(_BaseJob):
         cluster: Optional[str] = None,
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> "Job":
         ret_val = super().run(
             name=name,
@@ -71,6 +75,10 @@ class Job(_BaseJob):
             cluster=cluster,
             env=env,
             interruptible=interruptible,
+            image_credentials=image_credentials,
+            cluster_auth=cluster_auth,
+            artifacts_local=artifacts_local,
+            artifacts_remote=artifacts_remote,
         )
         # required for typing with "Job"
         assert isinstance(ret_val, cls)
@@ -85,8 +93,12 @@ class Job(_BaseJob):
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
         cluster: Optional[str] = None,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> None:
-        return self._internal_job._submit(
+        self._job = self._internal_job._submit(
             machine=machine,
             cluster=cluster,
             command=command,
@@ -94,7 +106,12 @@ class Job(_BaseJob):
             image=image,
             env=env,
             interruptible=interruptible,
+            image_credentials=image_credentials,
+            cluster_auth=cluster_auth,
+            artifacts_local=artifacts_local,
+            artifacts_remote=artifacts_remote,
         )
+        return self
 
     def stop(self) -> None:
         return self._internal_job.stop()
lightning_sdk/job/v1.py CHANGED
@@ -54,6 +54,8 @@ class _JobV1(_BaseJob):
             cluster=cluster,
             env=None,
             interruptible=interruptible,
+            image_credentials=None,
+            cluster_auth=False,
         )
 
     def _submit(
@@ -65,13 +67,20 @@ class _JobV1(_BaseJob):
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
         cluster: Optional[str] = None,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> None:
         if studio is None:
             raise ValueError("Studio is required for submitting jobs")
 
-        if image is not None:
+        if image is not None or image_credentials is not None or cluster_auth:
             raise ValueError("Image is not supported for submitting jobs")
 
+        if artifacts_local is not None or artifacts_remote is not None:
+            raise ValueError("Specifying how to persist artifacts is not yet supported with jobs")
+
         if env is not None:
             raise ValueError("Environment variables are not supported for submitting jobs")
 
lightning_sdk/job/v2.py CHANGED
@@ -34,6 +34,10 @@ class _JobV2(_BaseJob):
         env: Optional[Dict[str, str]] = None,
         interruptible: bool = False,
         cluster: Optional[str] = None,
+        image_credentials: Optional[str] = None,
+        cluster_auth: bool = False,
+        artifacts_local: Optional[str] = None,
+        artifacts_remote: Optional[str] = None,
     ) -> None:
         # Command is required if Studio is provided to know what to run
         # Image is mutually exclusive with Studio
@@ -62,6 +66,10 @@ class _JobV2(_BaseJob):
             machine=machine,
             interruptible=interruptible,
             env=env,
+            image_credentials=image_credentials,
+            cluster_auth=cluster_auth,
+            artifacts_local=artifacts_local,
+            artifacts_remote=artifacts_remote,
         )
         self._job = submitted
         self._name = submitted.name
@@ -108,10 +116,18 @@ class _JobV2(_BaseJob):
 
     @property
     def artifact_path(self) -> Optional[str]:
+        if self._guaranteed_job.spec.image != "":
+            if self._guaranteed_job.spec.artifacts_destination != "":
+                splits = self._guaranteed_job.spec.artifacts_destination.split(":")
+                return f"/teamspace/{splits[0]}_connections/{splits[1]}/{splits[2]}"
+            return None
+
         return f"/teamspace/jobs/{self._guaranteed_job.name}/artifacts"
 
     @property
     def snapshot_path(self) -> Optional[str]:
+        if self._guaranteed_job.spec.image != "":
+            return None
         return f"/teamspace/jobs/{self._guaranteed_job.name}/snapshot"
 
     @property
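
For image-based jobs, artifact_path above is now derived from the spec's artifacts_destination rather than the job name, and snapshot_path becomes None. A small worked example of the mapping implemented in the property, using the connection value shown earlier in this diff:

    # artifacts_destination = "efs:data:some-path"
    splits = "efs:data:some-path".split(":")
    artifact_path = f"/teamspace/{splits[0]}_connections/{splits[1]}/{splits[2]}"
    # -> "/teamspace/efs_connections/data/some-path"
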