lightning-sdk 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/base_studio_api.py +7 -1
  3. lightning_sdk/api/cluster_api.py +83 -1
  4. lightning_sdk/api/llm_api.py +27 -5
  5. lightning_sdk/api/studio_api.py +64 -0
  6. lightning_sdk/api/teamspace_api.py +127 -1
  7. lightning_sdk/api/utils.py +4 -0
  8. lightning_sdk/base_studio.py +14 -1
  9. lightning_sdk/cli/create.py +21 -1
  10. lightning_sdk/cli/deploy/__init__.py +0 -0
  11. lightning_sdk/cli/deploy/_auth.py +189 -0
  12. lightning_sdk/cli/deploy/devbox.py +157 -0
  13. lightning_sdk/cli/{serve.py → deploy/serve.py} +11 -322
  14. lightning_sdk/cli/download.py +44 -16
  15. lightning_sdk/cli/entrypoint.py +1 -1
  16. lightning_sdk/cli/open.py +21 -2
  17. lightning_sdk/cli/start.py +12 -3
  18. lightning_sdk/cli/upload.py +2 -4
  19. lightning_sdk/lightning_cloud/openapi/__init__.py +19 -0
  20. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +126 -1
  21. lightning_sdk/lightning_cloud/openapi/api/cloud_space_environment_template_service_api.py +97 -0
  22. lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +105 -0
  23. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
  24. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +752 -106
  25. lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +93 -0
  26. lightning_sdk/lightning_cloud/openapi/models/__init__.py +19 -0
  27. lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +53 -1
  28. lightning_sdk/lightning_cloud/openapi/models/cloudspaces_id_body.py +53 -1
  29. lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body.py +357 -0
  31. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body1.py +331 -0
  32. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +79 -1
  33. lightning_sdk/lightning_cloud/openapi/models/models_id_body.py +123 -0
  34. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +105 -1
  35. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  36. lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +29 -3
  37. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +79 -1
  38. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template.py +27 -1
  39. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_source_type.py +103 -0
  40. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +27 -1
  41. lightning_sdk/lightning_cloud/openapi/models/v1_create_deployment_request.py +27 -1
  42. lightning_sdk/lightning_cloud/openapi/models/v1_data_connection.py +27 -1
  43. lightning_sdk/lightning_cloud/openapi/models/v1_delete_deployment_alerting_policy_response.py +175 -0
  44. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +79 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_event.py +487 -0
  46. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy.py +409 -0
  47. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_frequency.py +105 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_operation.py +105 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_severity.py +106 -0
  50. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_type.py +111 -0
  51. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_recipients.py +175 -0
  52. lightning_sdk/lightning_cloud/openapi/models/v1_ge_list_deployment_routing_telemetry_response.py +27 -1
  53. lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_instance_open_ports_response.py +123 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_get_deployment_routing_telemetry_content_response.py +123 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_get_organization_storage_metadata_response.py +331 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +1 -27
  57. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +27 -1
  58. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_events_response.py +123 -0
  59. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_policies_response.py +175 -0
  60. lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +27 -1
  61. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +105 -1
  62. lightning_sdk/lightning_cloud/openapi/models/v1_project.py +27 -1
  63. lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +27 -1
  64. lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +29 -3
  65. lightning_sdk/lightning_cloud/openapi/models/v1_project_storage.py +53 -1
  66. lightning_sdk/lightning_cloud/openapi/models/v1_routing_telemetry.py +253 -0
  67. lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +1 -0
  68. lightning_sdk/lightning_cloud/openapi/models/v1_sleep_server_response.py +97 -0
  69. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_request.py +1 -27
  70. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +105 -53
  71. lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
  72. lightning_sdk/llm/llm.py +54 -8
  73. lightning_sdk/studio.py +40 -1
  74. lightning_sdk/teamspace.py +68 -0
  75. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/METADATA +1 -1
  76. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/RECORD +80 -58
  77. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/LICENSE +0 -0
  78. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/WHEEL +0 -0
  79. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/entry_points.txt +0 -0
  80. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.17.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -31,6 +31,6 @@ __all__ = [
31
31
  "User",
32
32
  ]
33
33
 
34
- __version__ = "0.2.15"
34
+ __version__ = "0.2.17"
35
35
  _check_version_and_prompt_upgrade(__version__)
36
36
  _set_tqdm_envvars_noninteractive()
@@ -16,11 +16,17 @@ class BaseStudioApi:
16
16
  """Retrieve the base studio by its ID."""
17
17
  try:
18
18
  return self._client.cloud_space_environment_template_service_get_cloud_space_environment_template(
19
- base_studio_id, org_id
19
+ base_studio_id, org_id=org_id
20
20
  )
21
21
  except ValueError as e:
22
22
  raise ValueError(f"Base studio {base_studio_id} does not exist") from e
23
23
 
24
+ def get_all_base_studios(self, org_id: str) -> List[V1CloudSpaceEnvironmentTemplate]:
25
+ """Retrieve all base studios for a given organization."""
26
+ return self._client.cloud_space_environment_template_service_list_cloud_space_environment_templates(
27
+ org_id=org_id
28
+ )
29
+
24
30
  def update_base_studio(
25
31
  self,
26
32
  base_studio_id: str,
@@ -1,4 +1,11 @@
1
- from lightning_sdk.lightning_cloud.openapi import Externalv1Cluster
1
+ from typing import Dict, List, Optional
2
+
3
+ from lightning_sdk.lightning_cloud.openapi import (
4
+ Externalv1Cluster,
5
+ V1CloudProvider,
6
+ V1ClusterType,
7
+ V1ListClusterAcceleratorsResponse,
8
+ )
2
9
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
3
10
 
4
11
 
@@ -20,3 +27,78 @@ class ClusterApi:
20
27
  if not res:
21
28
  raise ValueError(f"Cluster {cluster_id} does not exist")
22
29
  return res
30
+
31
+ def list_cluster_accelerators(self, cluster_id: str, org_id: str) -> V1ListClusterAcceleratorsResponse:
32
+ """Lists the accelerators for a given cluster.
33
+
34
+ :param cluster_id: cluster ID test
35
+ :param project_id: the project the cluster is supposed to be associated with
36
+ :param org_id: The owning org of this cluster
37
+ """
38
+ res = self._client.cluster_service_list_cluster_accelerators(
39
+ id=cluster_id,
40
+ org_id=org_id,
41
+ )
42
+ if not res:
43
+ raise ValueError(f"Cluster {cluster_id} does not exist")
44
+ return res
45
+
46
+ def list_global_clusters(self, project_id: str, org_id: str) -> List[Externalv1Cluster]:
47
+ """Lists the accelerators for a given project.
48
+
49
+ :param project_id: project ID test
50
+ :param org_id: The owning org of this project
51
+ """
52
+ res = self._client.cluster_service_list_clusters(
53
+ project_id=project_id,
54
+ org_id=org_id,
55
+ )
56
+ if not res:
57
+ raise ValueError(f"Project {project_id} does not exist")
58
+ filtered_clusters = filter(lambda x: x.spec.cluster_type == V1ClusterType.GLOBAL, res.clusters)
59
+ return list(filtered_clusters)
60
+
61
+ def get_cluster_provider_mapping(self, project_id: str, org_id: str) -> Dict[V1CloudProvider, str]:
62
+ """Gets the cluster provider mapping."""
63
+ res = self.list_global_clusters(
64
+ project_id=project_id,
65
+ org_id=org_id,
66
+ )
67
+ return {self._get_cluster_provider(cluster): cluster.id for cluster in res}
68
+
69
+ def _get_cluster_provider(self, cluster: Optional[Externalv1Cluster]) -> V1CloudProvider:
70
+ """Determines the cloud provider based on the cluster configuration.
71
+
72
+ Args:
73
+ cluster: An optional Externalv1Cluster object containing cluster specifications
74
+
75
+ Returns:
76
+ V1CloudProvider: The determined cloud provider, defaults to AWS if no match is found
77
+ """
78
+ if not cluster:
79
+ return V1CloudProvider.AWS
80
+
81
+ if (
82
+ cluster.spec
83
+ and cluster.spec.driver
84
+ and cluster.spec.driver in [V1CloudProvider.LIGHTNING, V1CloudProvider.DGX]
85
+ ):
86
+ return cluster.spec.driver
87
+
88
+ if cluster.spec:
89
+ if cluster.spec.aws_v1:
90
+ return V1CloudProvider.AWS
91
+ if cluster.spec.google_cloud_v1:
92
+ return V1CloudProvider.GCP
93
+ if cluster.spec.lambda_labs_v1:
94
+ return V1CloudProvider.LAMBDA_LABS
95
+ if cluster.spec.vultr_v1:
96
+ return V1CloudProvider.VULTR
97
+ if cluster.spec.slurm_v1:
98
+ return V1CloudProvider.SLURM
99
+ if cluster.spec.voltage_park_v1:
100
+ return V1CloudProvider.VOLTAGE_PARK
101
+ if cluster.spec.nebius_v1:
102
+ return V1CloudProvider.NEBIUS
103
+
104
+ return V1CloudProvider.AWS
@@ -1,5 +1,6 @@
1
+ import base64
1
2
  import json
2
- from typing import Generator, List, Optional, Union
3
+ from typing import Dict, Generator, List, Optional, Union
3
4
 
4
5
  from pip._vendor.urllib3 import HTTPResponse
5
6
 
@@ -55,25 +56,31 @@ class LLMApi:
55
56
  except json.JSONDecodeError:
56
57
  print("Error decoding JSON:", decoded_line)
57
58
 
59
+ def _encode_image_bytes_to_data_url(self, image: str) -> str:
60
+ with open(image, "rb") as image_file:
61
+ b64 = base64.b64encode(image_file.read()).decode("utf-8")
62
+ extension = image.split(".")[-1]
63
+ return f"data:image/{extension};base64,{b64}"
64
+
58
65
  def start_conversation(
59
66
  self,
60
67
  prompt: str,
61
68
  system_prompt: Optional[str],
62
69
  max_completion_tokens: int,
63
70
  assistant_id: str,
71
+ images: Optional[List[str]] = None,
64
72
  conversation_id: Optional[str] = None,
65
73
  billing_project_id: Optional[str] = None,
66
74
  name: Optional[str] = None,
75
+ metadata: Optional[Dict[str, str]] = None,
67
76
  stream: bool = False,
77
+ internal_conversation: bool = False,
68
78
  ) -> Union[V1ConversationResponseChunk, Generator[V1ConversationResponseChunk, None, None]]:
69
79
  body = {
70
80
  "message": {
71
81
  "author": {"role": "user"},
72
82
  "content": [
73
- {
74
- "contentType": "text",
75
- "parts": [prompt],
76
- }
83
+ {"contentType": "text", "parts": [prompt]},
77
84
  ],
78
85
  },
79
86
  "max_completion_tokens": max_completion_tokens,
@@ -81,7 +88,22 @@ class LLMApi:
81
88
  "billing_project_id": billing_project_id,
82
89
  "name": name,
83
90
  "stream": stream,
91
+ "metadata": metadata or {},
92
+ "internal_conversation": internal_conversation,
84
93
  }
94
+ if images:
95
+ for image in images:
96
+ url = image
97
+ if not image.startswith("http"):
98
+ url = self._encode_image_bytes_to_data_url(image)
99
+
100
+ body["message"]["content"].append(
101
+ {
102
+ "contentType": "image",
103
+ "parts": [url],
104
+ }
105
+ )
106
+
85
107
  result = self._client.assistants_service_start_conversation(body, assistant_id, _preload_content=not stream)
86
108
  if not stream:
87
109
  return result.result
@@ -32,9 +32,11 @@ from lightning_sdk.lightning_cloud.openapi import (
32
32
  IdForkBody1,
33
33
  IdStartBody,
34
34
  ProjectIdCloudspacesBody,
35
+ V1Assistant,
35
36
  V1CloudSpace,
36
37
  V1CloudSpaceInstanceConfig,
37
38
  V1CloudSpaceSeedFile,
39
+ V1CloudSpaceSourceType,
38
40
  V1CloudSpaceState,
39
41
  V1EndpointType,
40
42
  V1GetCloudSpaceInstanceStatusResponse,
@@ -46,6 +48,16 @@ from lightning_sdk.lightning_cloud.openapi import (
46
48
  V1UserRequestedComputeConfig,
47
49
  )
48
50
  from lightning_sdk.lightning_cloud.openapi.models import ProjectIdEndpointsBody
51
+ from lightning_sdk.lightning_cloud.openapi.models.project_id_agentmanagedendpoints_body import (
52
+ ProjectIdAgentmanagedendpointsBody,
53
+ )
54
+ from lightning_sdk.lightning_cloud.openapi.models.project_id_agents_body import (
55
+ ProjectIdAgentsBody,
56
+ )
57
+ from lightning_sdk.lightning_cloud.openapi.models.v1_endpoint import V1Endpoint
58
+ from lightning_sdk.lightning_cloud.openapi.models.v1_managed_endpoint import V1ManagedEndpoint
59
+ from lightning_sdk.lightning_cloud.openapi.models.v1_managed_model import V1ManagedModel
60
+ from lightning_sdk.lightning_cloud.openapi.models.v1_upstream_managed import V1UpstreamManaged
49
61
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
50
62
  from lightning_sdk.machine import Machine
51
63
 
@@ -110,6 +122,7 @@ class StudioApi:
110
122
  name: str,
111
123
  teamspace_id: str,
112
124
  cloud_account: Optional[str] = None,
125
+ source: Optional[V1CloudSpaceSourceType] = None,
113
126
  ) -> V1CloudSpace:
114
127
  """Create a Studio with a given name in a given Teamspace on a possibly given cloud_account."""
115
128
  body = ProjectIdCloudspacesBody(
@@ -117,6 +130,7 @@ class StudioApi:
117
130
  name=name,
118
131
  display_name=name,
119
132
  seed_files=[V1CloudSpaceSeedFile(path="main.py", contents="print('Hello, Lightning World!')\n")],
133
+ source=source,
120
134
  )
121
135
  studio = self._client.cloud_space_service_create_cloud_space(body, teamspace_id)
122
136
 
@@ -731,6 +745,56 @@ class StudioApi:
731
745
  )
732
746
  return endpoint.urls[0]
733
747
 
748
+ def create_assistant(self, studio_id: str, teamspace_id: str, port: int, assistant_name: str) -> V1Assistant:
749
+ target_teamspace = self._client.projects_service_get_project(teamspace_id)
750
+ org_id = ""
751
+ if target_teamspace.owner_type == "ORGANIZATION":
752
+ org_id = target_teamspace.owner_id
753
+ endpoint = self._client.endpoint_service_create_endpoint(
754
+ project_id=teamspace_id,
755
+ body=ProjectIdEndpointsBody(
756
+ ports=[str(port)],
757
+ cloudspace=V1UpstreamCloudSpace(
758
+ cloudspace_id=studio_id,
759
+ port=str(port),
760
+ type=V1EndpointType.PLUGIN_API,
761
+ ),
762
+ ),
763
+ )
764
+ valid_url = endpoint.urls[0]
765
+ managed_endpoint = self._client.assistants_service_create_assistant_managed_endpoint(
766
+ body=ProjectIdAgentmanagedendpointsBody(
767
+ endpoint=V1ManagedEndpoint(
768
+ name=assistant_name,
769
+ base_url=valid_url + "/v1",
770
+ models_metadata=[
771
+ V1ManagedModel(
772
+ name=assistant_name,
773
+ )
774
+ ],
775
+ ),
776
+ org_id=org_id,
777
+ ),
778
+ project_id=teamspace_id,
779
+ )
780
+
781
+ body = ProjectIdAgentsBody(
782
+ endpoint=V1Endpoint(
783
+ cloudspace=V1UpstreamCloudSpace(cloudspace_id=studio_id),
784
+ name=assistant_name,
785
+ managed=V1UpstreamManaged(id=managed_endpoint.endpoint.id),
786
+ ),
787
+ name=assistant_name,
788
+ model=assistant_name,
789
+ cloudspace_id=studio_id,
790
+ model_provider="",
791
+ )
792
+
793
+ return self._client.assistants_service_create_assistant(
794
+ body=body,
795
+ project_id=teamspace_id,
796
+ )
797
+
734
798
  def _create_app(
735
799
  self, studio_id: str, teamspace_id: str, cloud_account: str, plugin_type: str, **other_arguments: Any
736
800
  ) -> Externalv1LightningappInstance:
@@ -1,10 +1,20 @@
1
1
  import os
2
+ import tempfile
3
+ import zipfile
2
4
  from pathlib import Path
3
5
  from typing import Dict, List, Optional, Tuple
4
6
 
7
+ import requests
5
8
  from tqdm.auto import tqdm
6
9
 
7
- from lightning_sdk.api.utils import _download_model_files, _DummyBody, _get_model_version, _ModelFileUploader
10
+ from lightning_sdk.api.utils import (
11
+ _download_model_files,
12
+ _DummyBody,
13
+ _FileUploader,
14
+ _get_model_version,
15
+ _ModelFileUploader,
16
+ _resolve_teamspace_remote_path,
17
+ )
8
18
  from lightning_sdk.lightning_cloud.login import Auth
9
19
  from lightning_sdk.lightning_cloud.openapi import (
10
20
  Externalv1LightningappInstance,
@@ -17,6 +27,7 @@ from lightning_sdk.lightning_cloud.openapi import (
17
27
  V1ClusterAccelerator,
18
28
  V1Endpoint,
19
29
  V1Job,
30
+ V1LoginRequest,
20
31
  V1Model,
21
32
  V1ModelVersionArchive,
22
33
  V1MultiMachineJob,
@@ -331,3 +342,118 @@ class TeamspaceApi:
331
342
  model_id = self.get_model(teamspace_id=teamspace_id, model_name=model_name).id
332
343
  response = self.models_api.models_store_list_model_versions(project_id=teamspace_id, model_id=model_id)
333
344
  return response.versions
345
+
346
+ def upload_file(
347
+ self,
348
+ teamspace_id: str,
349
+ cloud_account: str,
350
+ file_path: str,
351
+ remote_path: str,
352
+ progress_bar: bool,
353
+ ) -> None:
354
+ """Uploads file to given remote path in the Teamspace drive."""
355
+ _FileUploader(
356
+ client=self._client,
357
+ teamspace_id=teamspace_id,
358
+ cloud_account=cloud_account,
359
+ file_path=file_path,
360
+ remote_path=_resolve_teamspace_remote_path(remote_path),
361
+ progress_bar=progress_bar,
362
+ )()
363
+
364
+ def download_file(
365
+ self,
366
+ path: str,
367
+ target_path: str,
368
+ teamspace_id: str,
369
+ cloud_account: str,
370
+ progress_bar: bool = True,
371
+ ) -> None:
372
+ """Downloads a given file in Teamspace drive to a target location."""
373
+ # TODO: Update this endpoint to permit basic auth
374
+ auth = Auth()
375
+ auth.authenticate()
376
+ token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token
377
+
378
+ query_params = {
379
+ "clusterId": cloud_account,
380
+ "key": _resolve_teamspace_remote_path(path),
381
+ "token": token,
382
+ }
383
+
384
+ r = requests.get(
385
+ f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
386
+ params=query_params,
387
+ stream=True,
388
+ )
389
+ total_length = int(r.headers.get("content-length"))
390
+
391
+ if progress_bar:
392
+ pbar = tqdm(
393
+ desc=f"Downloading {os.path.split(path)[1]}",
394
+ total=total_length,
395
+ unit="B",
396
+ unit_scale=True,
397
+ unit_divisor=1000,
398
+ )
399
+
400
+ pbar_update = pbar.update
401
+ else:
402
+ pbar_update = lambda x: None
403
+
404
+ target_dir = os.path.split(target_path)[0]
405
+ if target_dir:
406
+ os.makedirs(target_dir, exist_ok=True)
407
+ with open(target_path, "wb") as f:
408
+ for chunk in r.iter_content(chunk_size=4096 * 8):
409
+ f.write(chunk)
410
+ pbar_update(len(chunk))
411
+
412
+ def download_folder(
413
+ self,
414
+ path: str,
415
+ target_path: str,
416
+ teamspace_id: str,
417
+ cloud_account: str,
418
+ progress_bar: bool = True,
419
+ ) -> None:
420
+ """Downloads a given folder from Teamspace drive to a target location."""
421
+ # TODO: Update this endpoint to permit basic auth
422
+ auth = Auth()
423
+ auth.authenticate()
424
+ token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token
425
+
426
+ query_params = {
427
+ "clusterId": cloud_account,
428
+ "prefix": _resolve_teamspace_remote_path(path),
429
+ "token": token,
430
+ }
431
+
432
+ r = requests.get(
433
+ f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
434
+ params=query_params,
435
+ stream=True,
436
+ )
437
+
438
+ if progress_bar:
439
+ pbar = tqdm(
440
+ desc=f"Downloading {os.path.split(path)[1]}",
441
+ unit="B",
442
+ unit_scale=True,
443
+ unit_divisor=1000,
444
+ )
445
+
446
+ pbar_update = pbar.update
447
+ else:
448
+ pbar_update = lambda x: None
449
+
450
+ if target_path:
451
+ os.makedirs(target_path, exist_ok=True)
452
+
453
+ with tempfile.TemporaryFile() as f:
454
+ for chunk in r.iter_content(chunk_size=4096 * 8):
455
+ f.write(chunk)
456
+ pbar_update(len(chunk))
457
+
458
+ with zipfile.ZipFile(f) as z:
459
+ z.extractall(target_path)
@@ -355,6 +355,10 @@ def _sanitize_studio_remote_path(path: str, studio_id: str) -> str:
355
355
  return f"/cloudspaces/{studio_id}/code/content/{path.replace('/teamspace/studios/this_studio/', '')}"
356
356
 
357
357
 
358
+ def _resolve_teamspace_remote_path(path: str) -> str:
359
+ return f"/Uploads/{path.replace('/teamspace/studios/this_studio/', '')}"
360
+
361
+
358
362
  _DOWNLOAD_REQUEST_CHUNK_SIZE = 10 * _BYTES_PER_MB
359
363
  _DOWNLOAD_MIN_CHUNK_SIZE = 100 * _BYTES_PER_KB
360
364
 
@@ -45,7 +45,12 @@ class BaseStudio:
45
45
 
46
46
  self._base_studio_api = BaseStudioApi()
47
47
 
48
- self._base_studio = self._base_studio_api.get_base_studio(name, self._org.id)
48
+ if name is not None:
49
+ base_studio = self._base_studio_api.get_base_studio(name, self._org.id)
50
+
51
+ if base_studio is None:
52
+ raise ValueError(f"Base studio with name {name} does not exist in organization {self._org.name}")
53
+ self._base_studio = base_studio
49
54
 
50
55
  def update(
51
56
  self,
@@ -68,3 +73,11 @@ class BaseStudio:
68
73
  setup_script_text=setup_script_text,
69
74
  disabled=disabled,
70
75
  )
76
+
77
+ def list(self) -> List[V1CloudSpaceEnvironmentTemplate]:
78
+ """List all base studios in the organization.
79
+
80
+ Returns:
81
+ List[V1CloudSpaceEnvironmentTemplate]: A list of base studio templates.
82
+ """
83
+ return self._base_studio_api.get_all_base_studios(self._org.id)
@@ -7,9 +7,12 @@ import click
7
7
  from rich.console import Console
8
8
 
9
9
  from lightning_sdk import Machine, Studio
10
+ from lightning_sdk.api.cluster_api import ClusterApi
10
11
  from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
12
+ from lightning_sdk.studio import Provider
11
13
 
12
14
  _MACHINE_VALUES = tuple([machine.name for machine in Machine.__dict__.values() if isinstance(machine, Machine)])
15
+ _PROVIDER_VALUES = tuple([provider.value for provider in Provider])
13
16
 
14
17
 
15
18
  @click.group("create")
@@ -44,8 +47,18 @@ def create() -> None:
44
47
  "or fall back to the teamspace default."
45
48
  ),
46
49
  )
50
+ @click.option(
51
+ "--provider",
52
+ default=None,
53
+ type=click.Choice(_PROVIDER_VALUES),
54
+ help="The provider to create the studio on. If --cloud-account is specified, this option is prioritized.",
55
+ )
47
56
  def studio(
48
- name: str, teamspace: Optional[str] = None, start: Optional[str] = None, cloud_account: Optional[str] = None
57
+ name: str,
58
+ teamspace: Optional[str] = None,
59
+ start: Optional[str] = None,
60
+ cloud_account: Optional[str] = None,
61
+ provider: Optional[str] = None,
49
62
  ) -> None:
50
63
  """Create a new studio on the Lightning AI platform.
51
64
 
@@ -57,6 +70,13 @@ def studio(
57
70
  menu = _TeamspacesMenu()
58
71
  teamspace_resolved = menu._resolve_teamspace(teamspace)
59
72
 
73
+ if provider is not None:
74
+ cluster_api = ClusterApi()
75
+ cloud_account = cluster_api.get_cluster_provider_mapping(
76
+ teamspace_resolved.id,
77
+ teamspace_resolved.owner.id,
78
+ )[provider]
79
+
60
80
  # default cloud account to current studios cloud account if run from studio
61
81
  # else it will fall back to teamspace default in the backend
62
82
  if cloud_account is None:
File without changes