lightning-sdk 0.1.40__py3-none-any.whl → 0.1.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +8 -3
  3. lightning_sdk/api/ai_hub_api.py +3 -3
  4. lightning_sdk/api/deployment_api.py +6 -6
  5. lightning_sdk/api/job_api.py +32 -6
  6. lightning_sdk/api/mmt_api.py +60 -19
  7. lightning_sdk/api/studio_api.py +37 -19
  8. lightning_sdk/api/teamspace_api.py +34 -29
  9. lightning_sdk/api/utils.py +48 -35
  10. lightning_sdk/cli/ai_hub.py +3 -3
  11. lightning_sdk/cli/entrypoint.py +3 -1
  12. lightning_sdk/cli/mmt.py +11 -10
  13. lightning_sdk/cli/run.py +9 -8
  14. lightning_sdk/cli/serve.py +130 -0
  15. lightning_sdk/deployment/deployment.py +18 -12
  16. lightning_sdk/job/base.py +118 -24
  17. lightning_sdk/job/job.py +87 -9
  18. lightning_sdk/job/v1.py +75 -18
  19. lightning_sdk/job/v2.py +51 -15
  20. lightning_sdk/job/work.py +36 -7
  21. lightning_sdk/lightning_cloud/openapi/__init__.py +13 -0
  22. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +215 -5
  23. lightning_sdk/lightning_cloud/openapi/api/lit_logger_service_api.py +218 -0
  24. lightning_sdk/lightning_cloud/openapi/api/models_store_api.py +226 -0
  25. lightning_sdk/lightning_cloud/openapi/api/secret_service_api.py +5 -1
  26. lightning_sdk/lightning_cloud/openapi/api/snowflake_service_api.py +21 -1
  27. lightning_sdk/lightning_cloud/openapi/models/__init__.py +13 -0
  28. lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +27 -1
  29. lightning_sdk/lightning_cloud/openapi/models/deploymenttemplates_id_body.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/id_visibility_body.py +123 -0
  31. lightning_sdk/lightning_cloud/openapi/models/model_id_versions_body.py +29 -3
  32. lightning_sdk/lightning_cloud/openapi/models/project_id_multimachinejobs_body.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/project_id_snowflake_body.py +15 -67
  34. lightning_sdk/lightning_cloud/openapi/models/query_query_id_body.py +17 -69
  35. lightning_sdk/lightning_cloud/openapi/models/snowflake_export_body.py +29 -81
  36. lightning_sdk/lightning_cloud/openapi/models/snowflake_query_body.py +17 -69
  37. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
  38. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_spec.py +27 -1
  39. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_file_url_response.py +27 -1
  40. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +17 -17
  41. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_url_response.py +149 -0
  42. lightning_sdk/lightning_cloud/openapi/models/v1_get_project_balance_response.py +27 -1
  43. lightning_sdk/lightning_cloud/openapi/models/v1_header.py +175 -0
  44. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_list_multi_machine_job_events_response.py +123 -0
  46. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +29 -3
  47. lightning_sdk/lightning_cloud/openapi/models/v1_metrics_stream.py +27 -1
  48. lightning_sdk/lightning_cloud/openapi/models/v1_model_file.py +175 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +27 -1
  50. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_event.py +331 -0
  51. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_event_type.py +104 -0
  52. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_fault_tolerance.py +149 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_fault_tolerance_strategy.py +105 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_status.py +27 -1
  55. lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +2 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_secret_type.py +1 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_snowflake_data_connection.py +29 -81
  58. lightning_sdk/lightning_cloud/openapi/models/v1_system_metrics.py +29 -3
  59. lightning_sdk/lightning_cloud/openapi/models/v1_trainium_system_metrics.py +175 -0
  60. lightning_sdk/lightning_cloud/openapi/models/v1_update_metrics_stream_visibility_response.py +97 -0
  61. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +41 -67
  62. lightning_sdk/lightning_cloud/openapi/models/v1_validate_deployment_image_request.py +149 -0
  63. lightning_sdk/lightning_cloud/openapi/models/v1_validate_deployment_image_response.py +97 -0
  64. lightning_sdk/lightning_cloud/rest_client.py +2 -0
  65. lightning_sdk/mmt/__init__.py +3 -0
  66. lightning_sdk/{_mmt → mmt}/base.py +20 -14
  67. lightning_sdk/{_mmt → mmt}/mmt.py +46 -17
  68. lightning_sdk/mmt/v1.py +129 -0
  69. lightning_sdk/{_mmt → mmt}/v2.py +16 -21
  70. lightning_sdk/plugin.py +43 -16
  71. lightning_sdk/services/file_endpoint.py +11 -5
  72. lightning_sdk/studio.py +16 -9
  73. lightning_sdk/teamspace.py +26 -14
  74. lightning_sdk/utils/resolve.py +18 -0
  75. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/METADATA +3 -1
  76. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/RECORD +80 -66
  77. lightning_sdk/_mmt/__init__.py +0 -3
  78. lightning_sdk/_mmt/v1.py +0 -69
  79. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/LICENSE +0 -0
  80. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/WHEEL +0 -0
  81. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/entry_points.txt +0 -0
  82. {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -27,5 +27,5 @@ __all__ = [
27
27
  "AIHub",
28
28
  ]
29
29
 
30
- __version__ = "0.1.40"
30
+ __version__ = "0.1.42"
31
31
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -124,7 +124,7 @@ class AIHub:
124
124
  api_id: str,
125
125
  api_arguments: Optional[Dict[str, Any]] = None,
126
126
  name: Optional[str] = None,
127
- cluster: Optional[str] = None,
127
+ cloud_account: Optional[str] = None,
128
128
  teamspace: Optional[Union[str, "Teamspace"]] = None,
129
129
  org: Optional[Union[str, "Organization"]] = None,
130
130
  ) -> Dict[str, Union[str, bool]]:
@@ -143,7 +143,8 @@ class AIHub:
143
143
  api_id: The ID of the AIHub template you want to run.
144
144
  api_arguments: Additional API argument, such as model name, or batch size.
145
145
  name: Name for the deployed API. Defaults to None.
146
- cluster: The cluster where you want to run the template, such as "lightning-public-prod". Defaults to None.
146
+ cloud_account: The cloud account where you want to run the template, such as "lightning-public-prod".
147
+ Defaults to None.
147
148
  teamspace: The team or group for deployment. Defaults to None.
148
149
  org: The organization for deployment. Defaults to None.
149
150
 
@@ -160,7 +161,11 @@ class AIHub:
160
161
 
161
162
  api_arguments = api_arguments or {}
162
163
  deployment = self._api.run_api(
163
- template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, api_arguments=api_arguments
164
+ template_id=api_id,
165
+ cloud_account=cloud_account or "",
166
+ project_id=teamspace_id,
167
+ name=name,
168
+ api_arguments=api_arguments,
164
169
  )
165
170
  url = (
166
171
  quote(
@@ -105,12 +105,12 @@ class AIHubApi:
105
105
  elif not p.required:
106
106
  AIHubApi._update_parameters(job, p.placements, pattern, "")
107
107
  else:
108
- raise ValueError(f"API reqires argument '{p.name}' but is not provided with api_arguments.")
108
+ raise ValueError(f"API requires argument '{p.name}' but is not provided with api_arguments.")
109
109
 
110
110
  return job
111
111
 
112
112
  def run_api(
113
- self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], api_arguments: Dict[str, str]
113
+ self, template_id: str, project_id: str, cloud_account: str, name: Optional[str], api_arguments: Dict[str, str]
114
114
  ) -> V1Deployment:
115
115
  template = self._client.deployment_templates_service_get_deployment_template(template_id)
116
116
  name = name or template.name
@@ -121,7 +121,7 @@ class AIHubApi:
121
121
  project_id=project_id,
122
122
  body=CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(
123
123
  autoscaling=template.spec_v2.autoscaling,
124
- cluster_id=cluster_id,
124
+ cluster_id=cloud_account,
125
125
  endpoint=template.spec_v2.endpoint,
126
126
  name=name,
127
127
  replicas=0,
@@ -239,7 +239,7 @@ class DeploymentApi:
239
239
  command: Optional[str] = None,
240
240
  env: Optional[List[Union[Env, Secret]]] = None,
241
241
  spot: Optional[bool] = None,
242
- cluster_id: Optional[str] = None,
242
+ cloud_account: Optional[str] = None,
243
243
  min_replicas: Optional[int] = None,
244
244
  max_replicas: Optional[int] = None,
245
245
  name: Optional[str] = None,
@@ -270,7 +270,7 @@ class DeploymentApi:
270
270
  requires_release |= apply_change(deployment.spec, "command", command)
271
271
  requires_release |= apply_change(deployment.spec, "env", to_env(env))
272
272
  requires_release |= apply_change(deployment.spec, "env", to_health_check(health_check))
273
- requires_release |= apply_change(deployment.spec, "cluster_id", cluster_id)
273
+ requires_release |= apply_change(deployment.spec, "cluster_id", cloud_account)
274
274
  requires_release |= apply_change(deployment.spec, "spot", spot)
275
275
 
276
276
  if requires_release:
@@ -521,7 +521,7 @@ def to_health_check(
521
521
 
522
522
 
523
523
  def to_spec(
524
- cluster_id: Optional[str],
524
+ cloud_account: Optional[str],
525
525
  machine: Optional[Machine],
526
526
  environment: Optional[str],
527
527
  entrypoint: Optional[str],
@@ -530,8 +530,8 @@ def to_spec(
530
530
  env: Optional[List[Union[Secret, Env]]] = None,
531
531
  health_check: Optional[Union[HttpHealthCheck, ExecHealthCheck]] = None,
532
532
  ) -> V1JobSpec:
533
- if cluster_id is None:
534
- raise ValueError("The cluster should be defined.")
533
+ if cloud_account is None:
534
+ raise ValueError("The cloud account should be defined.")
535
535
 
536
536
  if machine is None:
537
537
  raise ValueError("The machine should be defined.")
@@ -540,7 +540,7 @@ def to_spec(
540
540
  raise ValueError("The environment should be defined.")
541
541
 
542
542
  return V1JobSpec(
543
- cluster_id=cluster_id,
543
+ cluster_id=cloud_account,
544
544
  command=command,
545
545
  entrypoint=entrypoint,
546
546
  env=to_env(env),
@@ -24,6 +24,7 @@ from lightning_sdk.lightning_cloud.openapi import (
24
24
  V1LightningappInstanceState,
25
25
  V1LightningappInstanceStatus,
26
26
  V1LightningworkSpec,
27
+ V1LightningworkState,
27
28
  V1ListLightningworkResponse,
28
29
  V1UserRequestedComputeConfig,
29
30
  )
@@ -107,7 +108,7 @@ class JobApiV1:
107
108
  command: str,
108
109
  studio_id: str,
109
110
  teamspace_id: str,
110
- cluster_id: str,
111
+ cloud_account: str,
111
112
  machine: Machine,
112
113
  interruptible: bool,
113
114
  ) -> Externalv1LightningappInstance:
@@ -116,7 +117,7 @@ class JobApiV1:
116
117
  self._client,
117
118
  studio_id=studio_id,
118
119
  teamspace_id=teamspace_id,
119
- cluster_id=cluster_id,
120
+ cloud_account=cloud_account,
120
121
  plugin_type="job",
121
122
  compute=_MACHINE_TO_COMPUTE_NAME[machine],
122
123
  name=name,
@@ -124,6 +125,31 @@ class JobApiV1:
124
125
  interruptible=interruptible,
125
126
  )
126
127
 
128
+ def get_status_from_work(self, work: Externalv1Lightningwork) -> "Status":
129
+ from lightning_sdk.status import Status
130
+
131
+ internal_status = work.status.phase
132
+
133
+ if internal_status in (
134
+ V1LightningworkState.UNSPECIFIED,
135
+ V1LightningworkState.IMAGE_BUILDING,
136
+ V1LightningworkState.PENDING,
137
+ V1LightningworkState.NOT_STARTED,
138
+ V1LightningworkState.DELETED,
139
+ ):
140
+ return Status.Pending
141
+
142
+ if internal_status == V1LightningworkState.RUNNING:
143
+ return Status.Running
144
+
145
+ if internal_status == V1LightningworkState.STOPPED:
146
+ return Status.Stopped
147
+
148
+ if internal_status == V1LightningworkState.FAILED:
149
+ return Status.Failed
150
+
151
+ return Status.Pending
152
+
127
153
 
128
154
  class JobApiV2:
129
155
  v2_job_state_pending = "pending"
@@ -141,7 +167,7 @@ class JobApiV2:
141
167
  self,
142
168
  name: str,
143
169
  command: Optional[str],
144
- cluster_id: Optional[str],
170
+ cloud_account: Optional[str],
145
171
  teamspace_id: str,
146
172
  studio_id: Optional[str],
147
173
  image: Optional[str],
@@ -149,7 +175,7 @@ class JobApiV2:
149
175
  interruptible: bool,
150
176
  env: Optional[Dict[str, str]],
151
177
  image_credentials: Optional[str],
152
- cluster_auth: bool,
178
+ cloud_account_auth: bool,
153
179
  artifacts_local: Optional[str],
154
180
  artifacts_remote: Optional[str],
155
181
  ) -> V1Job:
@@ -164,14 +190,14 @@ class JobApiV2:
164
190
 
165
191
  spec = V1JobSpec(
166
192
  cloudspace_id=studio_id or "",
167
- cluster_id=cluster_id or "",
193
+ cluster_id=cloud_account or "",
168
194
  command=command or "",
169
195
  env=env_vars,
170
196
  image=image or "",
171
197
  instance_name=instance_name,
172
198
  run_id=run_id,
173
199
  spot=interruptible,
174
- image_cluster_credentials=cluster_auth,
200
+ image_cluster_credentials=cloud_account_auth,
175
201
  image_secret_ref=image_credentials or "",
176
202
  artifacts_source=artifacts_local or "",
177
203
  artifacts_destination=artifacts_remote or "",
@@ -1,18 +1,23 @@
1
+ import json
1
2
  import time
2
- from typing import TYPE_CHECKING, Dict, Optional
3
+ from typing import TYPE_CHECKING, Dict, List, Optional
3
4
 
5
+ from lightning_sdk.api.job_api import JobApiV1
4
6
  from lightning_sdk.api.utils import (
5
7
  _COMPUTE_NAME_TO_MACHINE,
6
8
  _MACHINE_TO_COMPUTE_NAME,
9
+ _create_app,
7
10
  )
8
11
  from lightning_sdk.api.utils import (
9
12
  _get_cloud_url as _cloud_url,
10
13
  )
11
14
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__
12
15
  from lightning_sdk.lightning_cloud.openapi import (
16
+ Externalv1LightningappInstance,
13
17
  MultimachinejobsIdBody,
14
18
  ProjectIdMultimachinejobsBody,
15
19
  V1EnvVar,
20
+ V1Job,
16
21
  V1JobSpec,
17
22
  V1MultiMachineJob,
18
23
  V1MultiMachineJobState,
@@ -24,14 +29,43 @@ if TYPE_CHECKING:
24
29
  from lightning_sdk.status import Status
25
30
 
26
31
 
27
- class MMTApi:
28
- mmt_state_unspecified = "MultiMachineJob_STATE_UNSPECIFIED"
29
- mmt_state_running = "MultiMachineJob_STATE_RUNNING"
30
- mmt_state_stopped = "MultiMachineJob_STATE_STOPPED"
31
- mmt_state_deleted = "MultiMachineJob_STATE_DELETED"
32
- mmt_state_failed = "MultiMachineJob_STATE_FAILED"
33
- mmt_state_completed = "MultiMachineJob_STATE_COMPLETED"
32
+ class MMTApiV1(JobApiV1):
33
+ def __init__(self) -> None:
34
+ self._cloud_url = _cloud_url()
35
+ self._client = LightningClient(max_tries=7)
36
+
37
+ def submit_job(
38
+ self,
39
+ name: str,
40
+ num_machines: int,
41
+ command: Optional[str],
42
+ cloud_account: Optional[str],
43
+ teamspace_id: str,
44
+ studio_id: str,
45
+ machine: Machine,
46
+ interruptible: bool,
47
+ strategy: str,
48
+ ) -> Externalv1LightningappInstance:
49
+ """Creates a multi-machine job with given commands."""
50
+ distributed_args = {
51
+ "cloud_compute": _MACHINE_TO_COMPUTE_NAME[machine],
52
+ "num_instances": num_machines,
53
+ "strategy": strategy,
54
+ }
55
+ return _create_app(
56
+ client=self._client,
57
+ studio_id=studio_id,
58
+ teamspace_id=teamspace_id,
59
+ cloud_account=cloud_account or "",
60
+ plugin_type="distributed_plugin",
61
+ entrypoint=command,
62
+ name=name,
63
+ distributedArguments=json.dumps(distributed_args),
64
+ interruptible=interruptible,
65
+ )
66
+
34
67
 
68
+ class MMTApiV2:
35
69
  def __init__(self) -> None:
36
70
  self._cloud_url = _cloud_url()
37
71
  self._client = LightningClient(max_tries=7)
@@ -41,7 +75,7 @@ class MMTApi:
41
75
  name: str,
42
76
  num_machines: int,
43
77
  command: Optional[str],
44
- cluster_id: Optional[str],
78
+ cloud_account: Optional[str],
45
79
  teamspace_id: str,
46
80
  studio_id: Optional[str],
47
81
  image: Optional[str],
@@ -49,7 +83,7 @@ class MMTApi:
49
83
  interruptible: bool,
50
84
  env: Optional[Dict[str, str]],
51
85
  image_credentials: Optional[str],
52
- cluster_auth: bool,
86
+ cloud_account_auth: bool,
53
87
  artifacts_local: Optional[str],
54
88
  artifacts_remote: Optional[str],
55
89
  ) -> V1MultiMachineJob:
@@ -64,19 +98,22 @@ class MMTApi:
64
98
 
65
99
  spec = V1JobSpec(
66
100
  cloudspace_id=studio_id or "",
67
- cluster_id=cluster_id or "",
101
+ cluster_id=cloud_account or "",
68
102
  command=command or "",
103
+ entrypoint="sh -c",
69
104
  env=env_vars,
70
105
  image=image or "",
71
106
  instance_name=instance_name,
72
107
  run_id=run_id,
73
108
  spot=interruptible,
74
- image_cluster_credentials=cluster_auth,
109
+ image_cluster_credentials=cloud_account_auth,
75
110
  image_secret_ref=image_credentials or "",
76
111
  artifacts_source=artifacts_local or "",
77
112
  artifacts_destination=artifacts_remote or "",
78
113
  )
79
- body = ProjectIdMultimachinejobsBody(name=name, spec=spec, cluster_id=cluster_id or "", machines=num_machines)
114
+ body = ProjectIdMultimachinejobsBody(
115
+ name=name, spec=spec, cluster_id=cloud_account or "", machines=num_machines
116
+ )
80
117
 
81
118
  job: V1MultiMachineJob = self._client.jobs_service_create_multi_machine_job(project_id=teamspace_id, body=body)
82
119
  return job
@@ -106,7 +143,7 @@ class MMTApi:
106
143
  return
107
144
 
108
145
  if current_state != Status.Stopped:
109
- update_body = MultimachinejobsIdBody(desired_state=self.mmt_state_stopped)
146
+ update_body = MultimachinejobsIdBody(desired_state=V1MultiMachineJobState.STOPPED)
110
147
  self._client.jobs_service_update_multi_machine_job(body=update_body, project_id=teamspace_id, id=job_id)
111
148
 
112
149
  while True:
@@ -123,18 +160,22 @@ class MMTApi:
123
160
  def delete_job(self, job_id: str, teamspace_id: str) -> None:
124
161
  self._client.jobs_service_delete_multi_machine_job(project_id=teamspace_id, id=job_id)
125
162
 
163
+ def list_mmt_subjobs(self, job_id: str, teamspace_id: str) -> List[V1Job]:
164
+ jobs_resp = self._client.jobs_service_list_jobs(project_id=teamspace_id, multi_machine_job_id=job_id)
165
+ return jobs_resp.jobs
166
+
126
167
  def _job_state_to_external(self, state: V1MultiMachineJobState) -> "Status":
127
168
  from lightning_sdk.status import Status
128
169
 
129
- if str(state) == self.mmt_state_unspecified:
170
+ if str(state) == V1MultiMachineJobState.UNSPECIFIED:
130
171
  return Status.Pending
131
- if str(state) == self.mmt_state_running:
172
+ if str(state) == V1MultiMachineJobState.RUNNING:
132
173
  return Status.Running
133
- if str(state) == self.mmt_state_stopped:
174
+ if str(state) == V1MultiMachineJobState.STOPPED:
134
175
  return Status.Stopped
135
- if str(state) == self.mmt_state_completed:
176
+ if str(state) == V1MultiMachineJobState.COMPLETED:
136
177
  return Status.Completed
137
- if str(state) == self.mmt_state_failed:
178
+ if str(state) == V1MultiMachineJobState.FAILED:
138
179
  return Status.Failed
139
180
  return Status.Pending
140
181
 
@@ -108,11 +108,11 @@ class StudioApi:
108
108
  self,
109
109
  name: str,
110
110
  teamspace_id: str,
111
- cluster: Optional[str] = None,
111
+ cloud_account: Optional[str] = None,
112
112
  ) -> V1CloudSpace:
113
- """Create a Studio with a given name in a given Teamspace on a possibly given cluster."""
113
+ """Create a Studio with a given name in a given Teamspace on a possibly given cloud_account."""
114
114
  body = ProjectIdCloudspacesBody(
115
- cluster_id=cluster,
115
+ cluster_id=cloud_account,
116
116
  name=name,
117
117
  display_name=name,
118
118
  seed_files=[V1CloudSpaceSeedFile(path="main.py", contents="print('Hello, Lightning World!')\n")],
@@ -383,20 +383,32 @@ class StudioApi:
383
383
  self._client.cloud_space_service_delete_cloud_space(project_id=teamspace_id, id=studio_id)
384
384
 
385
385
  def upload_file(
386
- self, studio_id: str, teamspace_id: str, cluster_id: str, file_path: str, remote_path: str, progress_bar: bool
386
+ self,
387
+ studio_id: str,
388
+ teamspace_id: str,
389
+ cloud_account: str,
390
+ file_path: str,
391
+ remote_path: str,
392
+ progress_bar: bool,
387
393
  ) -> None:
388
394
  """Uploads file to given remote path on the studio."""
389
395
  _FileUploader(
390
396
  client=self._client,
391
397
  teamspace_id=teamspace_id,
392
- cluster_id=cluster_id,
398
+ cloud_account=cloud_account,
393
399
  file_path=file_path,
394
400
  remote_path=_sanitize_studio_remote_path(remote_path, studio_id),
395
401
  progress_bar=progress_bar,
396
402
  )()
397
403
 
398
404
  def download_file(
399
- self, path: str, target_path: str, studio_id: str, teamspace_id: str, cluster_id: str, progress_bar: bool = True
405
+ self,
406
+ path: str,
407
+ target_path: str,
408
+ studio_id: str,
409
+ teamspace_id: str,
410
+ cloud_account: str,
411
+ progress_bar: bool = True,
400
412
  ) -> None:
401
413
  """Downloads a given file from a Studio to a target location."""
402
414
  # TODO: Update this endpoint to permit basic auth
@@ -405,7 +417,7 @@ class StudioApi:
405
417
  token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token
406
418
 
407
419
  query_params = {
408
- "clusterId": cluster_id,
420
+ "clusterId": cloud_account,
409
421
  "key": _sanitize_studio_remote_path(path, studio_id),
410
422
  "token": token,
411
423
  }
@@ -439,7 +451,13 @@ class StudioApi:
439
451
  pbar_update(len(chunk))
440
452
 
441
453
  def download_folder(
442
- self, path: str, target_path: str, studio_id: str, teamspace_id: str, cluster_id: str, progress_bar: bool = True
454
+ self,
455
+ path: str,
456
+ target_path: str,
457
+ studio_id: str,
458
+ teamspace_id: str,
459
+ cloud_account: str,
460
+ progress_bar: bool = True,
443
461
  ) -> None:
444
462
  """Downloads a given folder from a Studio to a target location."""
445
463
  # TODO: Update this endpoint to permit basic auth
@@ -448,7 +466,7 @@ class StudioApi:
448
466
  token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token
449
467
 
450
468
  query_params = {
451
- "clusterId": cluster_id,
469
+ "clusterId": cloud_account,
452
470
  "prefix": _sanitize_studio_remote_path(path, studio_id),
453
471
  "token": token,
454
472
  }
@@ -553,14 +571,14 @@ class StudioApi:
553
571
  machine: Machine,
554
572
  studio_id: str,
555
573
  teamspace_id: str,
556
- cluster_id: str,
574
+ cloud_account: str,
557
575
  interruptible: bool,
558
576
  ) -> Externalv1LightningappInstance:
559
577
  """Creates a job with given commands."""
560
578
  return self._create_app(
561
579
  studio_id=studio_id,
562
580
  teamspace_id=teamspace_id,
563
- cluster_id=cluster_id,
581
+ cloud_account=cloud_account,
564
582
  plugin_type="job",
565
583
  entrypoint=entrypoint,
566
584
  name=name,
@@ -577,7 +595,7 @@ class StudioApi:
577
595
  strategy: str,
578
596
  studio_id: str,
579
597
  teamspace_id: str,
580
- cluster_id: str,
598
+ cloud_account: str,
581
599
  interruptible: bool,
582
600
  ) -> Externalv1LightningappInstance:
583
601
  """Creates a multi-machine job with given commands."""
@@ -589,7 +607,7 @@ class StudioApi:
589
607
  return self._create_app(
590
608
  studio_id=studio_id,
591
609
  teamspace_id=teamspace_id,
592
- cluster_id=cluster_id,
610
+ cloud_account=cloud_account,
593
611
  plugin_type="distributed_plugin",
594
612
  entrypoint=entrypoint,
595
613
  name=name,
@@ -605,7 +623,7 @@ class StudioApi:
605
623
  machine: Machine,
606
624
  studio_id: str,
607
625
  teamspace_id: str,
608
- cluster_id: str,
626
+ cloud_account: str,
609
627
  interruptible: bool,
610
628
  ) -> Externalv1LightningappInstance:
611
629
  """Creates a multi-machine job with given commands."""
@@ -616,7 +634,7 @@ class StudioApi:
616
634
  return self._create_app(
617
635
  studio_id=studio_id,
618
636
  teamspace_id=teamspace_id,
619
- cluster_id=cluster_id,
637
+ cloud_account=cloud_account,
620
638
  plugin_type="litdata",
621
639
  entrypoint=entrypoint,
622
640
  name=name,
@@ -638,14 +656,14 @@ class StudioApi:
638
656
  endpoint: str,
639
657
  studio_id: str,
640
658
  teamspace_id: str,
641
- cluster_id: str,
659
+ cloud_account: str,
642
660
  interruptible: bool,
643
661
  ) -> Externalv1LightningappInstance:
644
662
  """Creates an inference job for given endpoint."""
645
663
  return self._create_app(
646
664
  studio_id=studio_id,
647
665
  teamspace_id=teamspace_id,
648
- cluster_id=cluster_id,
666
+ cloud_account=cloud_account,
649
667
  plugin_type="inference_plugin",
650
668
  compute=_MACHINE_TO_COMPUTE_NAME[machine],
651
669
  entrypoint=entrypoint,
@@ -661,14 +679,14 @@ class StudioApi:
661
679
  )
662
680
 
663
681
  def _create_app(
664
- self, studio_id: str, teamspace_id: str, cluster_id: str, plugin_type: str, **other_arguments: Any
682
+ self, studio_id: str, teamspace_id: str, cloud_account: str, plugin_type: str, **other_arguments: Any
665
683
  ) -> Externalv1LightningappInstance:
666
684
  """Creates an arbitrary app."""
667
685
  return _create_app(
668
686
  self._client,
669
687
  studio_id=studio_id,
670
688
  teamspace_id=teamspace_id,
671
- cluster_id=cluster_id,
689
+ cloud_account=cloud_account,
672
690
  plugin_type=plugin_type,
673
691
  **other_arguments,
674
692
  )
@@ -2,6 +2,8 @@ import os
2
2
  from pathlib import Path
3
3
  from typing import Dict, List, Optional
4
4
 
5
+ from tqdm.auto import tqdm
6
+
5
7
  from lightning_sdk.api.utils import _download_model_files, _DummyBody, _get_model_version, _ModelFileUploader
6
8
  from lightning_sdk.lightning_cloud.login import Auth
7
9
  from lightning_sdk.lightning_cloud.openapi import (
@@ -64,12 +66,12 @@ class TeamspaceApi:
64
66
  teamspaces.append(self._get_teamspace_by_id(teamspace.project_id))
65
67
  return teamspaces
66
68
 
67
- def list_studios(self, teamspace_id: str, cluster_id: str = "") -> List[V1CloudSpace]:
69
+ def list_studios(self, teamspace_id: str, cloud_account: str = "") -> List[V1CloudSpace]:
68
70
  """List studios in teamspace."""
69
71
  kwargs = {"project_id": teamspace_id, "user_id": self._get_authed_user_id()}
70
72
 
71
- if cluster_id:
72
- kwargs["cluster_id"] = cluster_id
73
+ if cloud_account:
74
+ kwargs["cluster_id"] = cloud_account
73
75
 
74
76
  cloudspaces = []
75
77
 
@@ -85,8 +87,8 @@ class TeamspaceApi:
85
87
 
86
88
  return cloudspaces
87
89
 
88
- def list_clusters(self, teamspace_id: str) -> List[V1ProjectClusterBinding]:
89
- """Lists clusters in a teamspace."""
90
+ def list_cloud_accounts(self, teamspace_id: str) -> List[V1ProjectClusterBinding]:
91
+ """Lists cloud_accounts in a teamspace."""
90
92
  return self._client.projects_service_list_project_cluster_bindings(project_id=teamspace_id).clusters
91
93
 
92
94
  def _get_authed_user_id(self) -> str:
@@ -95,32 +97,32 @@ class TeamspaceApi:
95
97
  auth.authenticate()
96
98
  return auth.user_id
97
99
 
98
- def get_default_cluster_id(self, teamspace_id: str) -> str:
99
- """Get the default cluster id of the teamspace."""
100
+ def get_default_cloud_account(self, teamspace_id: str) -> str:
101
+ """Get the default cloud account id of the teamspace."""
100
102
  return self._client.projects_service_get_project(teamspace_id).project_settings.preferred_cluster
101
103
 
102
- def _determine_cluster_id(self, teamspace_id: str) -> str:
103
- """Attempts to determine the cluster id of the teamspace.
104
+ def _determine_cloud_account(self, teamspace_id: str) -> str:
105
+ """Attempts to determine the cloud account id of the teamspace.
104
106
 
105
107
  Raises an error if it's ambiguous.
106
108
 
107
109
  """
108
- # when you run from studio, the cluster is with env. vars
109
- cluster_id = os.getenv("LIGHTNING_CLUSTER_ID")
110
- if cluster_id:
111
- return cluster_id
110
+ # when you run from studio, the cloud account is with env. vars
111
+ cloud_account = os.getenv("LIGHTNING_CLUSTER_ID")
112
+ if cloud_account:
113
+ return cloud_account
112
114
 
113
115
  # if there is only one cluster, use that and ignore default setting :D
114
- cluster_ids = [c.cluster_id for c in self.list_clusters(teamspace_id=teamspace_id)]
115
- if len(cluster_ids) == 1:
116
- return cluster_ids[0]
117
- # otherwise, try to determine the default cluster, another API call but we do not care :(
118
- default_cluster_id = self.get_default_cluster_id(teamspace_id=teamspace_id)
119
- if default_cluster_id:
120
- return default_cluster_id
116
+ cloud_accounts = [c.cluster_id for c in self.list_cloud_accounts(teamspace_id=teamspace_id)]
117
+ if len(cloud_accounts) == 1:
118
+ return cloud_accounts[0]
119
+ # otherwise, try to determine the default cloud_account, another API call but we do not care :(
120
+ default_cloud_account = self.get_default_cloud_account(teamspace_id=teamspace_id)
121
+ if default_cloud_account:
122
+ return default_cloud_account
121
123
  raise RuntimeError(
122
- "Could not determine the current cluster id. Please provide it manually as input."
123
- f" Choices are: {', '.join(cluster_ids)}"
124
+ "Could not determine the current cloud account. Please provide it manually as input."
125
+ f" Choices are: {', '.join(cloud_accounts)}"
124
126
  )
125
127
 
126
128
  def create_agent(
@@ -174,18 +176,18 @@ class TeamspaceApi:
174
176
  metadata: Dict[str, str],
175
177
  private: bool,
176
178
  teamspace_id: str,
177
- cluster_id: str,
179
+ cloud_account: str,
178
180
  ) -> V1ModelVersionArchive:
179
181
  # ask if such model already exists by listing models with specific name
180
182
  models = self.models.models_store_list_models(project_id=teamspace_id, name=name).models
181
183
  if len(models) == 0:
182
184
  return self.models.models_store_create_model(
183
- body=ProjectIdModelsBody(cluster_id=cluster_id, metadata=metadata, name=name, private=private),
185
+ body=ProjectIdModelsBody(cluster_id=cloud_account, metadata=metadata, name=name, private=private),
184
186
  project_id=teamspace_id,
185
187
  )
186
188
  assert len(models) == 1, "Multiple models with the same name found"
187
189
  return self.models.models_store_create_model_version(
188
- body=ModelIdVersionsBody(cluster_id=cluster_id),
190
+ body=ModelIdVersionsBody(cluster_id=cloud_account),
189
191
  project_id=teamspace_id,
190
192
  model_id=models[0].id,
191
193
  )
@@ -209,7 +211,7 @@ class TeamspaceApi:
209
211
  version: str,
210
212
  local_path: Path,
211
213
  remote_path: str,
212
- cluster_id: str,
214
+ cloud_account: str,
213
215
  teamspace_id: str,
214
216
  progress_bar: bool = True,
215
217
  ) -> None:
@@ -218,7 +220,7 @@ class TeamspaceApi:
218
220
  model_id=model_id,
219
221
  version=version,
220
222
  teamspace_id=teamspace_id,
221
- cluster_id=cluster_id,
223
+ cloud_account=cloud_account,
222
224
  file_path=str(local_path),
223
225
  remote_path=str(remote_path),
224
226
  progress_bar=progress_bar,
@@ -231,20 +233,23 @@ class TeamspaceApi:
231
233
  version: str,
232
234
  root_path: Path,
233
235
  filepaths: List[Path],
234
- cluster_id: str,
236
+ cloud_account: str,
235
237
  teamspace_id: str,
236
238
  progress_bar: bool = True,
237
239
  ) -> None:
240
+ main_pbar = tqdm(total=len(filepaths), desc="Uploading files...", position=0) if progress_bar else None
238
241
  for filepath in filepaths:
239
242
  self.upload_model_file(
240
243
  model_id=model_id,
241
244
  version=version,
242
245
  local_path=filepath,
243
246
  remote_path=str(filepath.relative_to(root_path)),
244
- cluster_id=cluster_id,
247
+ cloud_account=cloud_account,
245
248
  teamspace_id=teamspace_id,
246
249
  progress_bar=progress_bar, # TODO: Global progress bar
247
250
  )
251
+ if main_pbar:
252
+ main_pbar.update(1)
248
253
 
249
254
  def complete_model_upload(self, model_id: str, version: str, teamspace_id: str) -> None:
250
255
  self.models.models_store_complete_model_upload(