lightning-sdk 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +38 -31
  3. lightning_sdk/api/ai_hub_api.py +29 -4
  4. lightning_sdk/api/job_api.py +6 -2
  5. lightning_sdk/api/teamspace_api.py +18 -14
  6. lightning_sdk/api/utils.py +19 -0
  7. lightning_sdk/cli/ai_hub.py +1 -1
  8. lightning_sdk/cli/entrypoint.py +2 -2
  9. lightning_sdk/cli/models.py +45 -15
  10. lightning_sdk/cli/run.py +80 -0
  11. lightning_sdk/job/base.py +61 -20
  12. lightning_sdk/job/job.py +29 -11
  13. lightning_sdk/job/v1.py +14 -10
  14. lightning_sdk/job/v2.py +10 -6
  15. lightning_sdk/lightning_cloud/openapi/__init__.py +13 -2
  16. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +5 -1
  17. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +680 -62
  18. lightning_sdk/lightning_cloud/openapi/models/__init__.py +13 -2
  19. lightning_sdk/lightning_cloud/openapi/models/create.py +6 -32
  20. lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +27 -1
  21. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  22. lightning_sdk/lightning_cloud/openapi/models/deploymenttemplates_id_body.py +32 -6
  23. lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +27 -1
  24. lightning_sdk/lightning_cloud/openapi/models/multimachinejobs_id_body.py +123 -0
  25. lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +53 -1
  26. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  27. lightning_sdk/lightning_cloud/openapi/models/project_id_multimachinejobs_body.py +201 -0
  28. lightning_sdk/lightning_cloud/openapi/models/update.py +6 -32
  29. lightning_sdk/lightning_cloud/openapi/models/v1_api_pricing_spec.py +149 -0
  30. lightning_sdk/lightning_cloud/openapi/models/v1_checkbox.py +29 -3
  31. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  32. lightning_sdk/lightning_cloud/openapi/models/v1_create_deployment_template_request.py +32 -6
  33. lightning_sdk/lightning_cloud/openapi/models/v1_data_connection.py +6 -32
  34. lightning_sdk/lightning_cloud/openapi/models/v1_delete_multi_machine_job_response.py +97 -0
  35. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  36. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_metrics.py +43 -17
  37. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_performance.py +305 -0
  38. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template.py +32 -6
  39. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_parameter.py +27 -1
  40. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_parameter_type.py +1 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template_summary.py +27 -1
  42. lightning_sdk/lightning_cloud/openapi/models/{v1_efs_data_connection.py → v1_efs_config.py} +22 -22
  43. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +27 -1
  44. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +53 -27
  45. lightning_sdk/lightning_cloud/openapi/models/v1_input.py +29 -3
  46. lightning_sdk/lightning_cloud/openapi/models/v1_job.py +27 -1
  47. lightning_sdk/lightning_cloud/openapi/models/v1_lambda_labs_direct_v1.py +125 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_list_multi_machine_jobs_response.py +123 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_machines_selector.py +149 -0
  50. lightning_sdk/lightning_cloud/openapi/models/v1_message.py +6 -6
  51. lightning_sdk/lightning_cloud/openapi/models/v1_message_content.py +6 -6
  52. lightning_sdk/lightning_cloud/openapi/models/v1_message_content_type.py +103 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_metrics_stream.py +53 -1
  54. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +383 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_state.py +108 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_status.py +279 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +2 -0
  58. lightning_sdk/lightning_cloud/openapi/models/v1_select.py +29 -3
  59. lightning_sdk/lightning_cloud/openapi/models/v1_system_info.py +27 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +131 -1
  61. lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
  62. lightning_sdk/lightning_cloud/openapi/models/v1_validate_data_connection_response.py +6 -32
  63. lightning_sdk/lightning_cloud/openapi/models/validate.py +6 -32
  64. lightning_sdk/teamspace.py +5 -0
  65. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/METADATA +1 -1
  66. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/RECORD +70 -58
  67. lightning_sdk/lightning_cloud/openapi/models/v1_efs_folder_data_connection.py +0 -201
  68. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/LICENSE +0 -0
  69. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/WHEEL +0 -0
  70. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/entry_points.txt +0 -0
  71. {lightning_sdk-0.1.36.dist-info → lightning_sdk-0.1.38.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -27,5 +27,5 @@ __all__ = [
27
27
  "AIHub",
28
28
  ]
29
29
 
30
- __version__ = "0.1.36"
30
+ __version__ = "0.1.38"
31
31
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -15,8 +15,17 @@ class AIHub:
15
15
  """An interface to interact with the AI Hub.
16
16
 
17
17
  Example:
18
- ai_hub = AIHub()
19
- api_list = ai_hub.list_apis()
18
+ from lightning_sdk import AIHub
19
+ hub = AIHub()
20
+
21
+ # List public API templates
22
+ api_list = hub.list_apis()
23
+
24
+ # Get detailed information about an API template
25
+ api_info = hub.api_info("temp_xxxx")
26
+
27
+ # Deploy an API template
28
+ deployment = hub.deploy("temp_xxxx")
20
29
  """
21
30
 
22
31
  def __init__(self) -> None:
@@ -28,7 +37,7 @@ class AIHub:
28
37
 
29
38
  Example:
30
39
  ai_hub = AIHub()
31
- api_info = ai_hub.api_info("api_12345")
40
+ api_info = ai_hub.api_info("temp_xxxx")
32
41
 
33
42
  Args:
34
43
  api_id: The ID of the API for which information is requested.
@@ -38,17 +47,7 @@ class AIHub:
38
47
  including its name, description, creation and update timestamps,
39
48
  parameters, tags, job specifications, and autoscaling settings.
40
49
  """
41
- template = self._api.api_info(api_id)
42
-
43
- api_arguments = [
44
- {
45
- "name": param.name,
46
- "short_description": param.short_description,
47
- "required": param.required,
48
- "default": param.input.default_value,
49
- }
50
- for param in template.parameter_spec.parameters
51
- ]
50
+ template, api_arguments = self._api.api_info(api_id)
52
51
 
53
52
  return {
54
53
  "name": template.name,
@@ -70,8 +69,11 @@ class AIHub:
70
69
  },
71
70
  }
72
71
 
73
- def list_apis(self, search: Optional[str] = None) -> List[Dict[str, str]]:
74
- """Get a list of AI Hub API templates.
72
+ def list_apis(
73
+ self,
74
+ search: Optional[str] = None,
75
+ ) -> List[Dict[str, str]]:
76
+ """Get a list of public AI Hub API templates.
75
77
 
76
78
  Example:
77
79
  ai_hub = AIHub()
@@ -93,9 +95,6 @@ class AIHub:
93
95
  "name": template.name,
94
96
  "description": template.description,
95
97
  "creator_username": template.creator_username,
96
- "created_on": template.creation_timestamp.strftime("%Y-%m-%d %H:%M:%S")
97
- if template.creation_timestamp
98
- else None,
99
98
  }
100
99
  results.append(result)
101
100
  return results
@@ -120,33 +119,33 @@ class AIHub:
120
119
  raise ValueError("You need to pass a teamspace or an org for your deployment.")
121
120
  return teamspace
122
121
 
123
- def deploy(
122
+ def run(
124
123
  self,
125
124
  api_id: str,
126
- cluster: Optional[str] = None,
125
+ api_arguments: Optional[Dict[str, Any]] = None,
127
126
  name: Optional[str] = None,
127
+ cluster: Optional[str] = None,
128
128
  teamspace: Optional[Union[str, "Teamspace"]] = None,
129
129
  org: Optional[Union[str, "Organization"]] = None,
130
- api_arguments: Optional[Dict[str, Any]] = None,
131
130
  ) -> Dict[str, Union[str, bool]]:
132
131
  """Deploy an API from the AI Hub.
133
132
 
134
133
  Example:
135
134
  from lightning_sdk import AIHub
136
135
  hub = AIHub()
137
- deployment = hub.deploy("temp_01jc37n6qpqkdptjpyep0z06hy")
136
+ deployment = hub.run("temp_xxxx")
138
137
 
139
138
  # Using API arguments
140
- api_arugments = {"batch_size" 10, "batch_timeout": 0.001, "env_token": "lit_xxxx"}
141
- deployment = hub.deploy("temp_01jc37n6qpqkdptjpyep0z06hy", api_arugments=api_arugments)
139
+ api_arugments = {"model": "unitary/toxic-bert", "batch_size": 10, "token": "lit_xxxx"}
140
+ deployment = hub.run("temp_xxxx", api_arugments=api_arugments)
142
141
 
143
142
  Args:
144
- api_id: The ID of the API you want to deploy.
145
- cluster: The cluster where you want to deploy the API, such as "lightning-public-prod". Defaults to None.
143
+ api_id: The ID of the AIHub template you want to run.
144
+ api_arguments: Additional API argument, such as model name, or batch size.
146
145
  name: Name for the deployed API. Defaults to None.
146
+ cluster: The cluster where you want to run the template, such as "lightning-public-prod". Defaults to None.
147
147
  teamspace: The team or group for deployment. Defaults to None.
148
148
  org: The organization for deployment. Defaults to None.
149
- api_arguments: Additional API argument, such as model name, or batch size.
150
149
 
151
150
  Returns:
152
151
  A dictionary containing the name of the deployed API,
@@ -160,14 +159,22 @@ class AIHub:
160
159
  teamspace_id = teamspace.id
161
160
 
162
161
  api_arguments = api_arguments or {}
163
- deployment = self._api.deploy_api(
162
+ deployment = self._api.run_api(
164
163
  template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, api_arguments=api_arguments
165
164
  )
166
- url = quote(f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}", safe=":/()")
165
+ url = (
166
+ quote(
167
+ f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}",
168
+ safe=":/()",
169
+ )
170
+ + "?app_id=deployment"
171
+ )
167
172
  print("Deployment available at:", url)
173
+
168
174
  return {
169
175
  "id": deployment.id,
170
176
  "name": deployment.name,
171
- "base_url": deployment.status.urls[0],
177
+ "deployment_url": url,
178
+ "api_endpoint": deployment.status.urls[0],
172
179
  "interruptible": deployment.spec.spot,
173
180
  }
@@ -1,5 +1,5 @@
1
1
  import traceback
2
- from typing import Dict, List, Optional
2
+ from typing import Dict, List, Optional, Tuple
3
3
 
4
4
  import backoff
5
5
 
@@ -22,15 +22,40 @@ class AIHubApi:
22
22
  def __init__(self) -> None:
23
23
  self._client = LightningClient(max_tries=3)
24
24
 
25
- def api_info(self, api_id: str) -> "V1DeploymentTemplate":
25
+ def api_info(self, api_id: str) -> Tuple[V1DeploymentTemplate, List[Dict[str, str]]]:
26
26
  try:
27
- return self._client.deployment_templates_service_get_deployment_template(api_id)
27
+ template = self._client.deployment_templates_service_get_deployment_template(api_id)
28
28
  except Exception as e:
29
29
  stack_trace = traceback.format_exc()
30
30
  if "record not found" in stack_trace:
31
31
  raise ValueError(f"api_id={api_id} not found.") from None
32
32
  raise e
33
33
 
34
+ api_arguments = []
35
+ for param in template.parameter_spec.parameters:
36
+ default = None
37
+ if param.type == V1DeploymentTemplateParameterType.INPUT and param.input:
38
+ default = param.input.default_value
39
+ if param.type == V1DeploymentTemplateParameterType.SELECT and param.select:
40
+ default = param.select.options[0]
41
+ if param.type == V1DeploymentTemplateParameterType.CHECKBOX and param.checkbox:
42
+ default = (
43
+ (param.checkbox.true_value or "True")
44
+ if param.checkbox.is_checked
45
+ else (param.checkbox.false_value or "False")
46
+ )
47
+
48
+ api_arguments.append(
49
+ {
50
+ "name": param.name,
51
+ "short_description": param.short_description,
52
+ "required": param.required,
53
+ "type": param.type,
54
+ "default": default,
55
+ }
56
+ )
57
+ return template, api_arguments
58
+
34
59
  @backoff.on_predicate(backoff.expo, lambda x: not x, max_tries=5)
35
60
  def list_apis(self, search_query: str) -> List[V1DeploymentTemplateGalleryResponse]:
36
61
  kwargs = {"show_globally_visible": True}
@@ -84,7 +109,7 @@ class AIHubApi:
84
109
 
85
110
  return job
86
111
 
87
- def deploy_api(
112
+ def run_api(
88
113
  self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], api_arguments: Dict[str, str]
89
114
  ) -> V1Deployment:
90
115
  template = self._client.deployment_templates_service_get_deployment_template(template_id)
@@ -141,13 +141,15 @@ class JobApiV2:
141
141
  self,
142
142
  name: str,
143
143
  command: Optional[str],
144
- cluster_id: str,
144
+ cluster_id: Optional[str],
145
145
  teamspace_id: str,
146
146
  studio_id: Optional[str],
147
147
  image: Optional[str],
148
148
  machine: Machine,
149
149
  interruptible: bool,
150
150
  env: Optional[Dict[str, str]],
151
+ image_credentials: Optional[str],
152
+ cluster_auth: bool,
151
153
  ) -> V1Job:
152
154
  env_vars = []
153
155
  if env is not None:
@@ -160,13 +162,15 @@ class JobApiV2:
160
162
 
161
163
  spec = V1JobSpec(
162
164
  cloudspace_id=studio_id or "",
163
- cluster_id=cluster_id,
165
+ cluster_id=cluster_id or "",
164
166
  command=command or "",
165
167
  env=env_vars,
166
168
  image=image or "",
167
169
  instance_name=instance_name,
168
170
  run_id=run_id,
169
171
  spot=interruptible,
172
+ image_cluster_credentials=cluster_auth,
173
+ image_secret_ref=image_credentials or "",
170
174
  )
171
175
  body = ProjectIdJobsBody(name=name, spec=spec)
172
176
 
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
3
  from pathlib import Path
4
4
  from typing import Dict, List, Optional
5
5
 
6
- from lightning_sdk.api.utils import _download_model_files, _DummyBody, _ModelFileUploader
6
+ from lightning_sdk.api.utils import _download_model_files, _DummyBody, _get_model_version, _ModelFileUploader
7
7
  from lightning_sdk.lightning_cloud.login import Auth
8
8
  from lightning_sdk.lightning_cloud.openapi import (
9
9
  ModelIdVersionsBody,
@@ -159,6 +159,16 @@ class TeamspaceApi:
159
159
 
160
160
  return self._client.assistants_service_create_assistant(body=body, project_id=teamspace_id)
161
161
 
162
+ # lazy property which is only created when needed
163
+ @property
164
+ def models(self) -> ModelsStoreApi:
165
+ if not self._models:
166
+ self._models = ModelsStoreApi(self._client.api_client)
167
+ return self._models
168
+
169
+ def get_model_version(self, name: str, version: str, teamspace_id: str) -> V1ModelVersionArchive:
170
+ return _get_model_version(client=self._client, name=name, version=version, teamspace_id=teamspace_id)
171
+
162
172
  def create_model(
163
173
  self,
164
174
  name: str,
@@ -167,17 +177,15 @@ class TeamspaceApi:
167
177
  teamspace_id: str,
168
178
  cluster_id: str,
169
179
  ) -> V1ModelVersionArchive:
170
- if not self._models:
171
- self._models = ModelsStoreApi(self._client.api_client)
172
180
  # ask if such model already exists by listing models with specific name
173
- models = self._models.models_store_list_models(project_id=teamspace_id, name=name).models
181
+ models = self.models.models_store_list_models(project_id=teamspace_id, name=name).models
174
182
  if len(models) == 0:
175
- return self._models.models_store_create_model(
183
+ return self.models.models_store_create_model(
176
184
  body=ProjectIdModelsBody(cluster_id=cluster_id, metadata=metadata, name=name, private=private),
177
185
  project_id=teamspace_id,
178
186
  )
179
187
  assert len(models) == 1, "Multiple models with the same name found"
180
- return self._models.models_store_create_model_version(
188
+ return self.models.models_store_create_model_version(
181
189
  body=ModelIdVersionsBody(cluster_id=cluster_id),
182
190
  project_id=teamspace_id,
183
191
  model_id=models[0].id,
@@ -185,18 +193,16 @@ class TeamspaceApi:
185
193
 
186
194
  def delete_model(self, name: str, version: Optional[str], teamspace_id: str) -> None:
187
195
  """Delete a model or a version from the model store."""
188
- if not self._models:
189
- self._models = ModelsStoreApi(self._client.api_client)
190
- models = self._models.models_store_list_models(project_id=teamspace_id, name=name).models
196
+ models = self.models.models_store_list_models(project_id=teamspace_id, name=name).models
191
197
  assert len(models) == 1, "Multiple models with the same name found"
192
198
  model_id = models[0].id
193
199
  # decide if delete only version of whole model
194
200
  if version:
195
201
  if version == "latest":
196
202
  version = models[0].latest_version
197
- self._models.models_store_delete_model_version(project_id=teamspace_id, model_id=model_id, version=version)
203
+ self.models.models_store_delete_model_version(project_id=teamspace_id, model_id=model_id, version=version)
198
204
  else:
199
- self._models.models_store_delete_model(project_id=teamspace_id, model_id=model_id)
205
+ self.models.models_store_delete_model(project_id=teamspace_id, model_id=model_id)
200
206
 
201
207
  def upload_model_file(
202
208
  self,
@@ -242,9 +248,7 @@ class TeamspaceApi:
242
248
  )
243
249
 
244
250
  def complete_model_upload(self, model_id: str, version: str, teamspace_id: str) -> None:
245
- if not self._models:
246
- self._models = ModelsStoreApi(self._client.api_client)
247
- self._models.models_store_complete_model_upload(
251
+ self.models.models_store_complete_model_upload(
248
252
  body=_DummyBody(),
249
253
  project_id=teamspace_id,
250
254
  model_id=model_id,
@@ -28,6 +28,7 @@ from lightning_sdk.lightning_cloud.openapi import (
28
28
  V1UploadProjectArtifactResponse,
29
29
  VersionUploadsBody,
30
30
  )
31
+ from lightning_sdk.lightning_cloud.openapi.models.v1_model_version_archive import V1ModelVersionArchive
31
32
 
32
33
  try:
33
34
  from lightning_sdk.lightning_cloud.openapi import AppsIdBody1 as AppsIdBody
@@ -492,6 +493,24 @@ class _FileDownloader:
492
493
  os.rename(tmp_filename, self.local_path)
493
494
 
494
495
 
496
+ def _get_model_version(client: LightningClient, teamspace_id: str, name: str, version: str) -> V1ModelVersionArchive:
497
+ api = ModelsStoreApi(client.api_client)
498
+ models = api.models_store_list_models(project_id=teamspace_id, name=name).models
499
+ if not models:
500
+ raise ValueError(f"Model `{name}` does not exist")
501
+ elif len(models) > 1:
502
+ raise ValueError("Multiple models with the same name found")
503
+ if version == "latest":
504
+ return models[0].latest_version
505
+ versions = api.models_store_list_model_versions(project_id=teamspace_id, model_id=models[0].id).versions
506
+ if not versions:
507
+ raise ValueError(f"Model `{name}` does not have any versions")
508
+ for ver in versions:
509
+ if ver.version == version:
510
+ return ver
511
+ raise ValueError(f"Model `{name}` does not have version `{version}`")
512
+
513
+
495
514
  def _download_model_files(
496
515
  client: LightningClient,
497
516
  teamspace_id: str,
@@ -46,4 +46,4 @@ class _AIHub(_StudiosMenu):
46
46
  teamspace: Teamspace to deploy the API to. Defaults to user's default teamspace.
47
47
  org: Organization to deploy the API to. Defaults to user's default organization.
48
48
  """
49
- return self._hub.deploy(api_id, cluster=cluster, name=name, teamspace=teamspace, org=org)
49
+ return self._hub.run(api_id, cluster=cluster, name=name, teamspace=teamspace, org=org)
@@ -5,6 +5,7 @@ from lightning_sdk.api.studio_api import _cloud_url
5
5
  from lightning_sdk.cli.ai_hub import _AIHub
6
6
  from lightning_sdk.cli.download import _Downloads
7
7
  from lightning_sdk.cli.legacy import _LegacyLightningCLI
8
+ from lightning_sdk.cli.run import _Run
8
9
  from lightning_sdk.cli.upload import _Uploads
9
10
  from lightning_sdk.lightning_cloud.login import Auth
10
11
 
@@ -19,8 +20,7 @@ class StudioCLI:
19
20
  self.upload = _Uploads()
20
21
  self.aihub = _AIHub()
21
22
 
22
- if _LIGHTNING_AVAILABLE:
23
- self.run = _LegacyLightningCLI()
23
+ self.run = _Run(legacy_run=_LegacyLightningCLI() if _LIGHTNING_AVAILABLE else None)
24
24
 
25
25
  def login(self) -> None:
26
26
  """Login to Lightning AI Studios."""
@@ -1,8 +1,11 @@
1
- from typing import Tuple
1
+ import os
2
+ from typing import Any, Dict, List, Tuple
2
3
 
3
4
  from lightning_sdk.api import OrgApi, UserApi
4
5
  from lightning_sdk.cli.exceptions import StudioCliError
6
+ from lightning_sdk.lightning_cloud.openapi.models import V1Membership, V1OwnerType
5
7
  from lightning_sdk.teamspace import Teamspace
8
+ from lightning_sdk.user import User
6
9
  from lightning_sdk.utils.resolve import _get_authed_user
7
10
 
8
11
 
@@ -17,22 +20,49 @@ def _parse_model_name(name: str) -> Tuple[str, str, str]:
17
20
  return org_name, teamspace_name, model_name
18
21
 
19
22
 
23
+ def _get_teamspace_and_path(
24
+ ts: V1Membership, org_api: OrgApi, user_api: UserApi, authed_user: User
25
+ ) -> Tuple[str, Dict[str, Any]]:
26
+ if ts.owner_type == V1OwnerType.ORGANIZATION:
27
+ org = org_api._get_org_by_id(ts.owner_id)
28
+ return f"{org.name}/{ts.name}", {"name": ts.name, "org": org.name}
29
+
30
+ if ts.owner_type == V1OwnerType.USER and ts.owner_id != authed_user.id:
31
+ user = user_api._get_user_by_id(ts.owner_id) # todo: check also the name
32
+ return f"{user.username}/{ts.name}", {"name": ts.name, "user": User(name=user.username)}
33
+
34
+ if ts.owner_type == V1OwnerType.USER:
35
+ return f"{authed_user.name}/{ts.name}", {"name": ts.name, "user": authed_user}
36
+
37
+ raise StudioCliError(f"Unknown organization type {ts.owner_type}")
38
+
39
+
40
+ def _list_teamspaces() -> List[str]:
41
+ org_api = OrgApi()
42
+ user_api = UserApi()
43
+ authed_user = _get_authed_user()
44
+
45
+ return [
46
+ _get_teamspace_and_path(ts, org_api, user_api, authed_user)[0]
47
+ for ts in user_api._get_all_teamspace_memberships("")
48
+ ]
49
+
50
+
20
51
  def _get_teamspace(name: str, organization: str) -> Teamspace:
21
52
  """Get a Teamspace object from the SDK."""
22
53
  org_api = OrgApi()
23
- user = _get_authed_user()
24
- teamspaces = {}
25
- for ts in UserApi()._get_all_teamspace_memberships(""):
26
- if ts.owner_type == "organization":
27
- org = org_api._get_org_by_id(ts.owner_id)
28
- teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
29
- elif ts.owner_type == "user": # todo: check also the name
30
- teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
31
- else:
32
- raise StudioCliError(f"Unknown organization type {ts.owner_type}")
54
+ user_api = UserApi()
55
+ authed_user = _get_authed_user()
33
56
 
34
57
  requested_teamspace = f"{organization}/{name}".lower()
35
- if requested_teamspace not in teamspaces:
36
- options = "\n\t".join(teamspaces.keys())
37
- raise StudioCliError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: \n\t{options}")
38
- return Teamspace(**teamspaces[requested_teamspace])
58
+
59
+ for ts in user_api._get_all_teamspace_memberships(""):
60
+ if ts.name != name:
61
+ continue
62
+
63
+ teamspace_path, teamspace = _get_teamspace_and_path(ts, org_api, user_api, authed_user)
64
+ if requested_teamspace == teamspace_path:
65
+ return Teamspace(**teamspace)
66
+
67
+ options = f"{os.linesep}\t".join(_list_teamspaces())
68
+ raise StudioCliError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: {os.linesep}\t{options}")
@@ -0,0 +1,80 @@
1
+ from typing import TYPE_CHECKING, Dict, Optional
2
+
3
+ from lightning_sdk.job import Job
4
+ from lightning_sdk.machine import Machine
5
+
6
+ if TYPE_CHECKING:
7
+ from lightning_sdk.cli.legacy import _LegacyLightningCLI
8
+
9
+ _MACHINE_VALUES = tuple([machine.value for machine in Machine])
10
+
11
+
12
+ class _Run:
13
+ """Run async workloads on the Lightning AI platform."""
14
+
15
+ def __init__(self, legacy_run: Optional["_LegacyLightningCLI"] = None) -> None:
16
+ if legacy_run is not None:
17
+ self.app = legacy_run.app
18
+ self.model = legacy_run.model
19
+
20
+ # Need to set the docstring here for f-strings to work.
21
+ # Sadly this is the only way to really show options as f-strings are not allowed as docstrings directly
22
+ # and fire does not show values for literals, just that it is a literal.
23
+ docstr = f"""Run async workloads using a docker image or a compute environment from your studio.
24
+
25
+ Args:
26
+ name: The name of the job. Needs to be unique within the teamspace.
27
+ machine: The machine type to run the job on. One of {", ".join(_MACHINE_VALUES)}.
28
+ command: The command to run inside your job. Required if using a studio. Optional if using an image.
29
+ If not provided for images, will run the container entrypoint and default command.
30
+ studio: The studio env to run the job with. Mutually exclusive with image.
31
+ image: The docker image to run the job with. Mutually exclusive with studio.
32
+ teamspace: The teamspace the job should be associated with. Defaults to the current teamspace.
33
+ org: The organization owning the teamspace (if any). Defaults to the current organization.
34
+ user: The user owning the teamspace (if any). Defaults to the current user.
35
+ cluster: The cluster to run the job on. Defaults to the studio cluster if running with studio compute env.
36
+ If not provided will fall back to the teamspaces default cluster.
37
+ env: Environment variables to set inside the job.
38
+ interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
39
+ image_credentials: The credentials used to pull the image. Required if the image is private.
40
+ This should be the name of the respective credentials secret created on the Lightning AI platform.
41
+ cluster_auth: Whether to authenticate with the cluster to pull the image.
42
+ Required if the registry is part of a cluster provider (e.g. ECR).
43
+ """
44
+ self.job.__func__.__doc__ = docstr
45
+
46
+ # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
47
+ # see https://github.com/google/python-fire/pull/513
48
+ # might need to move to different cli library
49
+ def job(
50
+ self,
51
+ name: str,
52
+ machine: str,
53
+ command: Optional[str] = None,
54
+ studio: Optional[str] = None,
55
+ image: Optional[str] = None,
56
+ teamspace: Optional[str] = None,
57
+ org: Optional[str] = None,
58
+ user: Optional[str] = None,
59
+ cluster: Optional[str] = None,
60
+ env: Optional[Dict[str, str]] = None,
61
+ interruptible: bool = False,
62
+ image_credentials: Optional[str] = None,
63
+ cluster_auth: bool = False,
64
+ ) -> None:
65
+ machine_enum = Machine(machine.upper())
66
+ Job.run(
67
+ name=name,
68
+ machine=machine_enum,
69
+ command=command,
70
+ studio=studio,
71
+ image=image,
72
+ teamspace=teamspace,
73
+ org=org,
74
+ user=user,
75
+ cluster=cluster,
76
+ env=env,
77
+ interruptible=interruptible,
78
+ image_credentials=image_credentials,
79
+ cluster_auth=cluster_auth,
80
+ )
lightning_sdk/job/base.py CHANGED
@@ -16,15 +16,20 @@ class _BaseJob(ABC):
16
16
  def __init__(
17
17
  self,
18
18
  name: str,
19
- teamspace: Union[str, "Teamspace"] = None,
20
- org: Union[str, "Organization"] = None,
21
- user: Union[str, "User"] = None,
22
- cluster: Optional[str] = None,
19
+ teamspace: Union[str, "Teamspace", None] = None,
20
+ org: Union[str, "Organization", None] = None,
21
+ user: Union[str, "User", None] = None,
23
22
  *,
24
23
  _fetch_job: bool = True,
25
24
  ) -> None:
26
- self._teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
27
- self._cluster = cluster
25
+ _teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
26
+ if _teamspace is None:
27
+ raise ValueError(
28
+ "Cannot resolve the teamspace from provided arguments."
29
+ f" Got teamspace={teamspace}, org={org}, user={user}."
30
+ )
31
+ else:
32
+ self._teamspace = _teamspace
28
33
  self._name = name
29
34
  self._job = None
30
35
 
@@ -37,18 +42,27 @@ class _BaseJob(ABC):
37
42
  name: str,
38
43
  machine: "Machine",
39
44
  command: Optional[str] = None,
40
- studio: Optional["Studio"] = None,
45
+ studio: Union["Studio", str, None] = None,
41
46
  image: Optional[str] = None,
42
- teamspace: Union[str, "Teamspace"] = None,
43
- org: Union[str, "Organization"] = None,
44
- user: Union[str, "User"] = None,
47
+ teamspace: Union[str, "Teamspace", None] = None,
48
+ org: Union[str, "Organization", None] = None,
49
+ user: Union[str, "User", None] = None,
45
50
  cluster: Optional[str] = None,
46
51
  env: Optional[Dict[str, str]] = None,
47
52
  interruptible: bool = False,
53
+ image_credentials: Optional[str] = None,
54
+ cluster_auth: bool = False,
48
55
  ) -> "_BaseJob":
56
+ from lightning_sdk.studio import Studio
57
+
49
58
  if not name:
50
59
  raise ValueError("A job needs to have a name!")
51
- if studio is not None:
60
+
61
+ if image is None:
62
+ if not isinstance(studio, Studio):
63
+ studio = Studio(name=studio, teamspace=teamspace, org=org, user=user, cluster=cluster, create_ok=False)
64
+
65
+ # studio is a Studio instance at this point
52
66
  if teamspace is None:
53
67
  teamspace = studio.teamspace
54
68
  else:
@@ -60,11 +74,39 @@ class _BaseJob(ABC):
60
74
  "Can only run jobs with Studio envs in the teamspace of that Studio."
61
75
  )
62
76
 
63
- # TODO: resolve studio and support string studios
64
- # TODO: assertions for studio to be on cluster
65
- # TODO: if cluster is not provided use studio cluster if provided, otherwise use default cluster from teamspace
66
- inst = cls(name=name, teamspace=teamspace, org=org, user=user, cluster=cluster, _fetch_job=False)
67
- inst._submit(machine=machine, command=command, studio=studio, image=image, env=env, interruptible=interruptible)
77
+ if cluster is None:
78
+ cluster = studio.cluster
79
+
80
+ if cluster != studio.cluster:
81
+ raise ValueError(
82
+ "Studio cluster does not match provided cluster. "
83
+ "Can only run jobs with Studio envs in the same cluster."
84
+ )
85
+
86
+ if image_credentials is not None:
87
+ raise ValueError("image_credentials is only supported when using a custom image")
88
+
89
+ if cluster_auth:
90
+ raise ValueError("cluster_auth is only supported when using a custom image")
91
+
92
+ else:
93
+ if studio is not None:
94
+ raise RuntimeError(
95
+ "image and studio are mutually exclusive as both define the environment to run the job in"
96
+ )
97
+
98
+ inst = cls(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=False)
99
+ inst._submit(
100
+ machine=machine,
101
+ cluster=cluster,
102
+ command=command,
103
+ studio=studio,
104
+ image=image,
105
+ env=env,
106
+ interruptible=interruptible,
107
+ image_credentials=image_credentials,
108
+ cluster_auth=cluster_auth,
109
+ )
68
110
  return inst
69
111
 
70
112
  @abstractmethod
@@ -76,6 +118,9 @@ class _BaseJob(ABC):
76
118
  image: Optional[str] = None,
77
119
  env: Optional[Dict[str, str]] = None,
78
120
  interruptible: bool = False,
121
+ cluster: Optional[str] = None,
122
+ image_credentials: Optional[str] = None,
123
+ cluster_auth: bool = False,
79
124
  ) -> None:
80
125
  """Submits a job and updates the internal _job attribute as well as the _name attribute."""
81
126
 
@@ -123,7 +168,3 @@ class _BaseJob(ABC):
123
168
  @property
124
169
  def teamspace(self) -> "Teamspace":
125
170
  return self._teamspace
126
-
127
- @property
128
- def cluster(self) -> Optional[str]:
129
- return self._cluster