lightning-sdk 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/base_studio_api.py +7 -1
  3. lightning_sdk/api/cluster_api.py +83 -1
  4. lightning_sdk/api/llm_api.py +24 -5
  5. lightning_sdk/api/studio_api.py +3 -0
  6. lightning_sdk/api/teamspace_api.py +127 -1
  7. lightning_sdk/api/utils.py +4 -0
  8. lightning_sdk/base_studio.py +14 -1
  9. lightning_sdk/cli/create.py +21 -1
  10. lightning_sdk/cli/deploy/__init__.py +0 -0
  11. lightning_sdk/cli/deploy/_auth.py +189 -0
  12. lightning_sdk/cli/deploy/devbox.py +157 -0
  13. lightning_sdk/cli/{serve.py → deploy/serve.py} +11 -322
  14. lightning_sdk/cli/download.py +44 -16
  15. lightning_sdk/cli/entrypoint.py +1 -1
  16. lightning_sdk/cli/open.py +21 -2
  17. lightning_sdk/cli/start.py +12 -3
  18. lightning_sdk/cli/upload.py +2 -4
  19. lightning_sdk/lightning_cloud/openapi/__init__.py +18 -0
  20. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +121 -0
  21. lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +105 -0
  22. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
  23. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +747 -105
  24. lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +93 -0
  25. lightning_sdk/lightning_cloud/openapi/models/__init__.py +18 -0
  26. lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +27 -1
  27. lightning_sdk/lightning_cloud/openapi/models/cloudspaces_id_body.py +53 -1
  28. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body.py +331 -0
  29. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body1.py +305 -0
  30. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +53 -1
  31. lightning_sdk/lightning_cloud/openapi/models/models_id_body.py +123 -0
  32. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +105 -1
  33. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  34. lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +29 -3
  35. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +53 -1
  36. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_source_type.py +103 -0
  37. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +27 -1
  38. lightning_sdk/lightning_cloud/openapi/models/v1_delete_deployment_alerting_policy_response.py +175 -0
  39. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +53 -1
  40. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_event.py +487 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy.py +383 -0
  42. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_frequency.py +105 -0
  43. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_operation.py +105 -0
  44. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_severity.py +106 -0
  45. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_type.py +111 -0
  46. lightning_sdk/lightning_cloud/openapi/models/v1_ge_list_deployment_routing_telemetry_response.py +27 -1
  47. lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_instance_open_ports_response.py +123 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_get_deployment_routing_telemetry_content_response.py +123 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_get_organization_storage_metadata_response.py +331 -0
  50. lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +1 -27
  51. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +27 -1
  52. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_events_response.py +123 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_policies_response.py +175 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +27 -1
  55. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +105 -1
  56. lightning_sdk/lightning_cloud/openapi/models/v1_project.py +27 -1
  57. lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +27 -1
  58. lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +29 -3
  59. lightning_sdk/lightning_cloud/openapi/models/v1_project_storage.py +53 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_routing_telemetry.py +253 -0
  61. lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +1 -0
  62. lightning_sdk/lightning_cloud/openapi/models/v1_sleep_server_response.py +97 -0
  63. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_request.py +1 -27
  64. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +79 -27
  65. lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
  66. lightning_sdk/llm/llm.py +52 -8
  67. lightning_sdk/studio.py +32 -1
  68. lightning_sdk/teamspace.py +68 -0
  69. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/METADATA +1 -1
  70. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/RECORD +74 -53
  71. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/LICENSE +0 -0
  72. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/WHEEL +0 -0
  73. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/entry_points.txt +0 -0
  74. {lightning_sdk-0.2.15.dist-info → lightning_sdk-0.2.16.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -31,6 +31,6 @@ __all__ = [
31
31
  "User",
32
32
  ]
33
33
 
34
- __version__ = "0.2.15"
34
+ __version__ = "0.2.16"
35
35
  _check_version_and_prompt_upgrade(__version__)
36
36
  _set_tqdm_envvars_noninteractive()
@@ -16,11 +16,17 @@ class BaseStudioApi:
16
16
  """Retrieve the base studio by its ID."""
17
17
  try:
18
18
  return self._client.cloud_space_environment_template_service_get_cloud_space_environment_template(
19
- base_studio_id, org_id
19
+ base_studio_id, org_id=org_id
20
20
  )
21
21
  except ValueError as e:
22
22
  raise ValueError(f"Base studio {base_studio_id} does not exist") from e
23
23
 
24
def get_all_base_studios(self, org_id: str) -> List[V1CloudSpaceEnvironmentTemplate]:
    """Return every base studio (environment template) registered for an organization.

    Args:
        org_id: ID of the organization whose base studios should be listed.

    Returns:
        The client's list-endpoint response for ``org_id``, passed through unchanged.
        NOTE(review): annotated as a list; confirm the client does not return a
        response wrapper object (other list endpoints here expose e.g. ``.clusters``).
    """
    list_templates = self._client.cloud_space_environment_template_service_list_cloud_space_environment_templates
    return list_templates(org_id=org_id)
29
+
24
30
  def update_base_studio(
25
31
  self,
26
32
  base_studio_id: str,
@@ -1,4 +1,11 @@
1
- from lightning_sdk.lightning_cloud.openapi import Externalv1Cluster
1
+ from typing import Dict, List, Optional
2
+
3
+ from lightning_sdk.lightning_cloud.openapi import (
4
+ Externalv1Cluster,
5
+ V1CloudProvider,
6
+ V1ClusterType,
7
+ V1ListClusterAcceleratorsResponse,
8
+ )
2
9
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
3
10
 
4
11
 
@@ -20,3 +27,78 @@ class ClusterApi:
20
27
  if not res:
21
28
  raise ValueError(f"Cluster {cluster_id} does not exist")
22
29
  return res
30
+
31
def list_cluster_accelerators(self, cluster_id: str, org_id: str) -> V1ListClusterAcceleratorsResponse:
    """List the accelerators available on a cluster.

    Args:
        cluster_id: ID of the cluster to query.
        org_id: ID of the organization that owns the cluster.

    Returns:
        The accelerator listing returned by the cluster service.

    Raises:
        ValueError: If the service returns an empty (falsy) response.
    """
    response = self._client.cluster_service_list_cluster_accelerators(
        id=cluster_id,
        org_id=org_id,
    )
    if not response:
        raise ValueError(f"Cluster {cluster_id} does not exist")
    return response
45
+
46
def list_global_clusters(self, project_id: str, org_id: str) -> List[Externalv1Cluster]:
    """List the global clusters visible to a project.

    Args:
        project_id: ID of the project whose clusters should be listed.
        org_id: ID of the organization that owns the project.

    Returns:
        Only the clusters whose spec marks them as ``V1ClusterType.GLOBAL``.

    Raises:
        ValueError: If the service returns an empty (falsy) response.
    """
    res = self._client.cluster_service_list_clusters(
        project_id=project_id,
        org_id=org_id,
    )
    if not res:
        raise ValueError(f"Project {project_id} does not exist")
    # A cluster without a spec cannot be classified; skip it instead of raising
    # AttributeError, and tolerate a missing cluster list.
    return [
        cluster
        for cluster in (res.clusters or [])
        if cluster.spec is not None and cluster.spec.cluster_type == V1ClusterType.GLOBAL
    ]
60
+
61
def get_cluster_provider_mapping(self, project_id: str, org_id: str) -> Dict[V1CloudProvider, str]:
    """Map each cloud provider to the ID of a global cluster running on it.

    When several global clusters resolve to the same provider, the one listed
    last wins because later entries overwrite earlier ones.
    """
    mapping: Dict[V1CloudProvider, str] = {}
    for cluster in self.list_global_clusters(project_id=project_id, org_id=org_id):
        provider = self._get_cluster_provider(cluster)
        mapping[provider] = cluster.id
    return mapping
68
+
69
def _get_cluster_provider(self, cluster: Optional[Externalv1Cluster]) -> V1CloudProvider:
    """Work out which cloud provider backs a cluster.

    Args:
        cluster: The cluster to inspect; may be None.

    Returns:
        V1CloudProvider: The provider, falling back to AWS when the cluster,
        its spec, or any recognizable configuration is missing.
    """
    spec = cluster.spec if cluster else None
    if not spec:
        return V1CloudProvider.AWS

    # These drivers name the provider directly, regardless of per-cloud config.
    if spec.driver and spec.driver in (V1CloudProvider.LIGHTNING, V1CloudProvider.DGX):
        return spec.driver

    # Otherwise the first populated per-cloud config decides; order matters.
    provider_by_config = (
        ("aws_v1", V1CloudProvider.AWS),
        ("google_cloud_v1", V1CloudProvider.GCP),
        ("lambda_labs_v1", V1CloudProvider.LAMBDA_LABS),
        ("vultr_v1", V1CloudProvider.VULTR),
        ("slurm_v1", V1CloudProvider.SLURM),
        ("voltage_park_v1", V1CloudProvider.VOLTAGE_PARK),
        ("nebius_v1", V1CloudProvider.NEBIUS),
    )
    for config_attr, provider in provider_by_config:
        if getattr(spec, config_attr):
            return provider

    return V1CloudProvider.AWS
@@ -1,5 +1,6 @@
1
+ import base64
1
2
  import json
2
- from typing import Generator, List, Optional, Union
3
+ from typing import Dict, Generator, List, Optional, Union
3
4
 
4
5
  from pip._vendor.urllib3 import HTTPResponse
5
6
 
@@ -55,25 +56,29 @@ class LLMApi:
55
56
  except json.JSONDecodeError:
56
57
  print("Error decoding JSON:", decoded_line)
57
58
 
59
+ def _encode_image_bytes_to_data_url(self, image: str, mime_type: str = "image/jpeg") -> str:
60
+ with open(image, "rb") as image_file:
61
+ b64 = base64.b64encode(image_file.read()).decode("utf-8")
62
+ return f"data:{mime_type};base64,{b64}"
63
+
58
64
  def start_conversation(
59
65
  self,
60
66
  prompt: str,
61
67
  system_prompt: Optional[str],
62
68
  max_completion_tokens: int,
63
69
  assistant_id: str,
70
+ images: Optional[List[str]] = None,
64
71
  conversation_id: Optional[str] = None,
65
72
  billing_project_id: Optional[str] = None,
66
73
  name: Optional[str] = None,
74
+ metadata: Optional[Dict[str, str]] = None,
67
75
  stream: bool = False,
68
76
  ) -> Union[V1ConversationResponseChunk, Generator[V1ConversationResponseChunk, None, None]]:
69
77
  body = {
70
78
  "message": {
71
79
  "author": {"role": "user"},
72
80
  "content": [
73
- {
74
- "contentType": "text",
75
- "parts": [prompt],
76
- }
81
+ {"contentType": "text", "parts": [prompt]},
77
82
  ],
78
83
  },
79
84
  "max_completion_tokens": max_completion_tokens,
@@ -81,7 +86,21 @@ class LLMApi:
81
86
  "billing_project_id": billing_project_id,
82
87
  "name": name,
83
88
  "stream": stream,
89
+ "metadata": metadata or {},
84
90
  }
91
+ if images:
92
+ for image in images:
93
+ url = image
94
+ if not image.startswith("http"):
95
+ url = self._encode_image_bytes_to_data_url(image)
96
+
97
+ body["message"]["content"].append(
98
+ {
99
+ "contentType": "image",
100
+ "parts": [url],
101
+ }
102
+ )
103
+
85
104
  result = self._client.assistants_service_start_conversation(body, assistant_id, _preload_content=not stream)
86
105
  if not stream:
87
106
  return result.result
@@ -35,6 +35,7 @@ from lightning_sdk.lightning_cloud.openapi import (
35
35
  V1CloudSpace,
36
36
  V1CloudSpaceInstanceConfig,
37
37
  V1CloudSpaceSeedFile,
38
+ V1CloudSpaceSourceType,
38
39
  V1CloudSpaceState,
39
40
  V1EndpointType,
40
41
  V1GetCloudSpaceInstanceStatusResponse,
@@ -110,6 +111,7 @@ class StudioApi:
110
111
  name: str,
111
112
  teamspace_id: str,
112
113
  cloud_account: Optional[str] = None,
114
+ source: Optional[V1CloudSpaceSourceType] = None,
113
115
  ) -> V1CloudSpace:
114
116
  """Create a Studio with a given name in a given Teamspace on a possibly given cloud_account."""
115
117
  body = ProjectIdCloudspacesBody(
@@ -117,6 +119,7 @@ class StudioApi:
117
119
  name=name,
118
120
  display_name=name,
119
121
  seed_files=[V1CloudSpaceSeedFile(path="main.py", contents="print('Hello, Lightning World!')\n")],
122
+ source=source,
120
123
  )
121
124
  studio = self._client.cloud_space_service_create_cloud_space(body, teamspace_id)
122
125
 
@@ -1,10 +1,20 @@
1
1
  import os
2
+ import tempfile
3
+ import zipfile
2
4
  from pathlib import Path
3
5
  from typing import Dict, List, Optional, Tuple
4
6
 
7
+ import requests
5
8
  from tqdm.auto import tqdm
6
9
 
7
- from lightning_sdk.api.utils import _download_model_files, _DummyBody, _get_model_version, _ModelFileUploader
10
+ from lightning_sdk.api.utils import (
11
+ _download_model_files,
12
+ _DummyBody,
13
+ _FileUploader,
14
+ _get_model_version,
15
+ _ModelFileUploader,
16
+ _resolve_teamspace_remote_path,
17
+ )
8
18
  from lightning_sdk.lightning_cloud.login import Auth
9
19
  from lightning_sdk.lightning_cloud.openapi import (
10
20
  Externalv1LightningappInstance,
@@ -17,6 +27,7 @@ from lightning_sdk.lightning_cloud.openapi import (
17
27
  V1ClusterAccelerator,
18
28
  V1Endpoint,
19
29
  V1Job,
30
+ V1LoginRequest,
20
31
  V1Model,
21
32
  V1ModelVersionArchive,
22
33
  V1MultiMachineJob,
@@ -331,3 +342,118 @@ class TeamspaceApi:
331
342
  model_id = self.get_model(teamspace_id=teamspace_id, model_name=model_name).id
332
343
  response = self.models_api.models_store_list_model_versions(project_id=teamspace_id, model_id=model_id)
333
344
  return response.versions
345
+
346
def upload_file(
    self,
    teamspace_id: str,
    cloud_account: str,
    file_path: str,
    remote_path: str,
    progress_bar: bool,
) -> None:
    """Upload a local file into the Teamspace drive.

    Args:
        teamspace_id: ID of the target Teamspace.
        cloud_account: Cloud account (cluster) the upload goes through.
        file_path: Local path of the file to upload.
        remote_path: Destination path, mapped into the drive by ``_resolve_teamspace_remote_path``.
        progress_bar: Whether the uploader shows a progress bar.
    """
    uploader = _FileUploader(
        client=self._client,
        teamspace_id=teamspace_id,
        cloud_account=cloud_account,
        file_path=file_path,
        remote_path=_resolve_teamspace_remote_path(remote_path),
        progress_bar=progress_bar,
    )
    uploader()
363
+
364
def download_file(
    self,
    path: str,
    target_path: str,
    teamspace_id: str,
    cloud_account: str,
    progress_bar: bool = True,
) -> None:
    """Download a single file from the Teamspace drive to a local path.

    Args:
        path: Path of the file in the Teamspace drive.
        target_path: Local destination; missing parent directories are created.
        teamspace_id: ID of the Teamspace that owns the drive.
        cloud_account: Cloud account (cluster) to download from.
        progress_bar: Whether to display a tqdm progress bar.

    Raises:
        requests.HTTPError: If the download endpoint responds with an error status.
    """
    # TODO: Update this endpoint to permit basic auth
    auth = Auth()
    auth.authenticate()
    token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token

    query_params = {
        "clusterId": cloud_account,
        "key": _resolve_teamspace_remote_path(path),
        "token": token,
    }

    with requests.get(
        f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
        params=query_params,
        stream=True,
    ) as r:
        # Fail loudly rather than writing an error payload into the target file.
        r.raise_for_status()
        # Content-Length may be absent (e.g. chunked transfer); tqdm accepts total=None.
        content_length = r.headers.get("content-length")
        total_length = int(content_length) if content_length is not None else None

        target_dir = os.path.split(target_path)[0]
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # disable=True turns the bar into a no-op while keeping a single code path.
        with tqdm(
            desc=f"Downloading {os.path.split(path)[1]}",
            total=total_length,
            unit="B",
            unit_scale=True,
            unit_divisor=1000,
            disable=not progress_bar,
        ) as pbar:
            with open(target_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=4096 * 8):
                    f.write(chunk)
                    pbar.update(len(chunk))
411
+
412
def download_folder(
    self,
    path: str,
    target_path: str,
    teamspace_id: str,
    cloud_account: str,
    progress_bar: bool = True,
) -> None:
    """Download a folder from the Teamspace drive and extract it into a local directory.

    The endpoint's response is buffered in a temporary file and then opened as a
    zip archive, whose contents are extracted into ``target_path``.

    Args:
        path: Folder prefix in the Teamspace drive.
        target_path: Local directory to extract into; created if missing.
        teamspace_id: ID of the Teamspace that owns the drive.
        cloud_account: Cloud account (cluster) to download from.
        progress_bar: Whether to display a tqdm progress bar.

    Raises:
        requests.HTTPError: If the download endpoint responds with an error status.
    """
    # TODO: Update this endpoint to permit basic auth
    auth = Auth()
    auth.authenticate()
    token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token

    query_params = {
        "clusterId": cloud_account,
        "prefix": _resolve_teamspace_remote_path(path),
        "token": token,
    }

    with requests.get(
        f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
        params=query_params,
        stream=True,
    ) as r:
        # Fail loudly rather than handing an error payload to ZipFile.
        r.raise_for_status()
        content_length = r.headers.get("content-length")
        total_length = int(content_length) if content_length is not None else None

        if target_path:
            os.makedirs(target_path, exist_ok=True)

        with tempfile.TemporaryFile() as f:
            # disable=True turns the bar into a no-op while keeping a single code path.
            with tqdm(
                desc=f"Downloading {os.path.split(path)[1]}",
                total=total_length,
                unit="B",
                unit_scale=True,
                unit_divisor=1000,
                disable=not progress_bar,
            ) as pbar:
                for chunk in r.iter_content(chunk_size=4096 * 8):
                    f.write(chunk)
                    pbar.update(len(chunk))

            f.seek(0)
            with zipfile.ZipFile(f) as z:
                z.extractall(target_path)
@@ -355,6 +355,10 @@ def _sanitize_studio_remote_path(path: str, studio_id: str) -> str:
355
355
  return f"/cloudspaces/{studio_id}/code/content/{path.replace('/teamspace/studios/this_studio/', '')}"
356
356
 
357
357
 
358
+ def _resolve_teamspace_remote_path(path: str) -> str:
359
+ return f"/Uploads/{path.replace('/teamspace/studios/this_studio/', '')}"
360
+
361
+
358
362
  _DOWNLOAD_REQUEST_CHUNK_SIZE = 10 * _BYTES_PER_MB
359
363
  _DOWNLOAD_MIN_CHUNK_SIZE = 100 * _BYTES_PER_KB
360
364
 
@@ -45,7 +45,12 @@ class BaseStudio:
45
45
 
46
46
  self._base_studio_api = BaseStudioApi()
47
47
 
48
- self._base_studio = self._base_studio_api.get_base_studio(name, self._org.id)
48
+ if name is not None:
49
+ base_studio = self._base_studio_api.get_base_studio(name, self._org.id)
50
+
51
+ if base_studio is None:
52
+ raise ValueError(f"Base studio with name {name} does not exist in organization {self._org.name}")
53
+ self._base_studio = base_studio
49
54
 
50
55
  def update(
51
56
  self,
@@ -68,3 +73,11 @@ class BaseStudio:
68
73
  setup_script_text=setup_script_text,
69
74
  disabled=disabled,
70
75
  )
76
+
77
def list(self) -> List[V1CloudSpaceEnvironmentTemplate]:
    """Return every base studio of the organization this instance is bound to.

    Returns:
        List[V1CloudSpaceEnvironmentTemplate]: The organization's base studio templates.
    """
    api = self._base_studio_api
    return api.get_all_base_studios(self._org.id)
@@ -7,9 +7,12 @@ import click
7
7
  from rich.console import Console
8
8
 
9
9
  from lightning_sdk import Machine, Studio
10
+ from lightning_sdk.api.cluster_api import ClusterApi
10
11
  from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
12
+ from lightning_sdk.studio import Provider
11
13
 
12
14
  _MACHINE_VALUES = tuple([machine.name for machine in Machine.__dict__.values() if isinstance(machine, Machine)])
15
+ _PROVIDER_VALUES = tuple([provider.value for provider in Provider])
13
16
 
14
17
 
15
18
  @click.group("create")
@@ -44,8 +47,18 @@ def create() -> None:
44
47
  "or fall back to the teamspace default."
45
48
  ),
46
49
  )
50
+ @click.option(
51
+ "--provider",
52
+ default=None,
53
+ type=click.Choice(_PROVIDER_VALUES),
54
+ help="The provider to create the studio on. If --cloud-account is specified, this option is prioritized.",
55
+ )
47
56
  def studio(
48
- name: str, teamspace: Optional[str] = None, start: Optional[str] = None, cloud_account: Optional[str] = None
57
+ name: str,
58
+ teamspace: Optional[str] = None,
59
+ start: Optional[str] = None,
60
+ cloud_account: Optional[str] = None,
61
+ provider: Optional[str] = None,
49
62
  ) -> None:
50
63
  """Create a new studio on the Lightning AI platform.
51
64
 
@@ -57,6 +70,13 @@ def studio(
57
70
  menu = _TeamspacesMenu()
58
71
  teamspace_resolved = menu._resolve_teamspace(teamspace)
59
72
 
73
+ if provider is not None:
74
+ cluster_api = ClusterApi()
75
+ cloud_account = cluster_api.get_cluster_provider_mapping(
76
+ teamspace_resolved.id,
77
+ teamspace_resolved.owner.id,
78
+ )[provider]
79
+
60
80
  # default cloud account to current studios cloud account if run from studio
61
81
  # else it will fall back to teamspace default in the backend
62
82
  if cloud_account is None:
File without changes
@@ -0,0 +1,189 @@
1
+ import os
2
+ import time
3
+ from datetime import datetime
4
+ from enum import Enum
5
+ from typing import Any, List, Optional, TypedDict
6
+ from urllib.parse import urlencode
7
+
8
+ from rich.console import Console
9
+ from rich.prompt import Confirm
10
+
11
+ from lightning_sdk import Teamspace
12
+ from lightning_sdk.api import UserApi
13
+ from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
14
+ from lightning_sdk.lightning_cloud import env
15
+ from lightning_sdk.lightning_cloud.login import Auth, AuthServer
16
+ from lightning_sdk.lightning_cloud.openapi import V1CloudSpace
17
+ from lightning_sdk.lightning_cloud.rest_client import LightningClient
18
+ from lightning_sdk.utils.resolve import _get_authed_user, _resolve_teamspace
19
+
20
+ LITSERVE_CODE = os.environ.get("LITSERVE_CODE", "j39bzk903h")
21
+ _POLL_TIMEOUT = 120
22
+
23
+
24
+ class _AuthMode(Enum):
25
+ DEVBOX = "dev"
26
+ DEPLOY = "deploy"
27
+
28
+
29
class _AuthServer(AuthServer):
    """Browser-login server whose sign-in URL also carries the requested auth mode."""

    def __init__(self, mode: _AuthMode, *args: Any, **kwargs: Any) -> None:
        # Stored before the parent initializer runs.
        self._mode = mode
        super().__init__(*args, **kwargs)

    def get_auth_url(self, port: int) -> str:
        """Build the sign-in URL that redirects back to the local callback on ``port``."""
        query = {
            "redirectTo": f"http://localhost:{port}/login-complete",
            "mode": self._mode.value,
            "okbhrt": LITSERVE_CODE,
        }
        encoded_query = urlencode(query)
        return f"{env.LIGHTNING_CLOUD_URL}/sign-in?{encoded_query}"
38
+
39
+
40
class _Auth(Auth):
    """``Auth`` variant that logs in via ``_AuthServer`` for a given mode, optionally asking first."""

    def __init__(self, mode: _AuthMode, shall_confirm: bool = False) -> None:
        super().__init__()
        self._mode = mode
        self._shall_confirm = shall_confirm

    def _run_server(self) -> None:
        """Run the browser sign-in flow, first asking for confirmation when configured to.

        Raises:
            RuntimeError: If the user declines the confirmation prompt.
        """
        if self._shall_confirm and not Confirm.ask(
            "Authenticating with Lightning AI. This will open a browser window. Continue?", default=True
        ):
            raise RuntimeError(
                "Login cancelled. Please login to Lightning AI to deploy the API. Run `lightning login` to login."
            ) from None

        print("Opening browser for authentication...")
        print("Please come back to the terminal after logging in.")
        # Give the user a moment to read the messages before the browser opens.
        time.sleep(3)
        _AuthServer(self._mode).login_with_browser(self)
59
+
60
+
61
def authenticate(mode: _AuthMode, shall_confirm: bool = True) -> None:
    """Log in to Lightning AI through the browser flow for ``mode``.

    Args:
        mode: Sign-in flow to request.
        shall_confirm: Ask the user before opening the browser.
    """
    _Auth(mode, shall_confirm).authenticate()
64
+
65
+
66
def select_teamspace(teamspace: Optional[str], org: Optional[str], user: Optional[str]) -> Teamspace:
    """Resolve the Teamspace to deploy into.

    With an explicit ``teamspace`` name, it is resolved against ``org``/``user``.
    Without one, the authenticated user's only candidate teamspace is used if there
    is exactly one; otherwise resolution is deferred to ``_TeamspacesMenu._resolve_teamspace``.

    Args:
        teamspace: Teamspace name, or None to pick one automatically.
        org: Organization owning ``teamspace`` (used only when ``teamspace`` is given).
        user: User owning ``teamspace`` (used only when ``teamspace`` is given).
    """
    if teamspace is not None:
        return _resolve_teamspace(teamspace=teamspace, org=org, user=user)

    authed_user = _get_authed_user()
    menu = _TeamspacesMenu()
    possible_teamspaces = menu._get_possible_teamspaces(authed_user)
    if len(possible_teamspaces) == 1:
        # Build the Teamspace from the candidate's own owner (as _Onboarding.select_teamspace
        # does) instead of the caller's org/user, which need not match it on this path.
        only = next(iter(possible_teamspaces.values()))
        return Teamspace(name=only["name"], org=only["org"], user=only["user"])

    return menu._resolve_teamspace(teamspace)
78
+
79
+
80
+ class _UserStatus(TypedDict):
81
+ verified: bool
82
+ onboarded: bool
83
+
84
+
85
def poll_verified_status(timeout: int = _POLL_TIMEOUT) -> _UserStatus:
    """Poll the authenticated user's status until they are verified or ``timeout`` seconds pass.

    Args:
        timeout: Maximum number of seconds to keep polling (checked every 5 seconds).

    Returns:
        The most recently observed status; ``verified`` is False if the timeout was hit.
    """
    user_api = UserApi()
    user = _get_authed_user()
    # A monotonic clock is immune to wall-clock adjustments that would distort the wait.
    deadline = time.monotonic() + timeout
    result: _UserStatus = {"onboarded": False, "verified": False}
    while True:
        status = user_api.get_user(name=user.name).status
        result["onboarded"] = status.completed_project_onboarding
        result["verified"] = status.verified
        if status.verified or time.monotonic() > deadline:
            return result
        time.sleep(5)
101
+
102
+
103
+ class _OnboardingStatus(Enum):
104
+ NOT_VERIFIED = "not_verified"
105
+ ONBOARDING = "onboarding"
106
+ ONBOARDED = "onboarded"
107
+
108
+
109
class _Onboarding:
    """Checks and waits on the authenticated user's account setup before deploying."""

    def __init__(self, console: Console) -> None:
        self.console = console
        self.user = _get_authed_user()
        self.user_api = UserApi()
        self.client = LightningClient(max_tries=7)

    @property
    def verified(self) -> bool:
        """Whether the user's account is verified (fetched from the API on every access)."""
        return self.user_api.get_user(name=self.user.name).status.verified

    @property
    def is_onboarded(self) -> bool:
        """Whether the user completed project onboarding (fetched from the API on every access)."""
        return self.user_api.get_user(name=self.user.name).status.completed_project_onboarding

    @property
    def can_join_org(self) -> bool:
        """Whether at least one organization is joinable by the user."""
        return len(self.client.organizations_service_list_joinable_organizations().joinable_organizations) > 0

    @property
    def status(self) -> _OnboardingStatus:
        """Current onboarding stage; verification is checked before onboarding."""
        if not self.verified:
            return _OnboardingStatus.NOT_VERIFIED
        if self.is_onboarded:
            return _OnboardingStatus.ONBOARDED
        return _OnboardingStatus.ONBOARDING

    def _wait_user_onboarding(self, timeout: int = _POLL_TIMEOUT) -> None:
        """Block until the user has completed project onboarding.

        Args:
            timeout: Maximum number of seconds to wait (checked every 5 seconds).

        Raises:
            RuntimeError: If onboarding does not complete within ``timeout`` seconds.
        """
        if self.status == _OnboardingStatus.ONBOARDED:
            return

        self.console.print("Waiting for account setup. Visit lightning.ai")
        # Monotonic deadline: unaffected by wall-clock changes during the wait.
        deadline = time.monotonic() + timeout
        while True:
            time.sleep(5)
            # Check completion once per iteration. Re-deriving `status` in the loop
            # condition could observe success and then fall through to the timeout error.
            if self.is_onboarded:
                return
            if time.monotonic() > deadline:
                raise RuntimeError("Timed out waiting for onboarding status")

    def get_cloudspace_id(self, teamspace: Teamspace) -> Optional[str]:
        """Return the ID of the teamspace's newest studio if it is the scratch studio, else None.

        Raises:
            RuntimeError: If the teamspace has no studios.
        """
        cloudspaces: List[V1CloudSpace] = self.client.cloud_space_service_list_cloud_spaces(teamspace.id).cloudspaces
        if not cloudspaces:
            raise RuntimeError("Error creating deployment! Finish account setup at lightning.ai first.")
        # max() keeps the first of equal timestamps, like the stable reverse sort it replaces.
        newest = max(cloudspaces, key=lambda cloudspace: cloudspace.created_at)
        # name/display_name may be unset; treat missing values as empty strings.
        if "scratch-studio" in (newest.name or "") or "scratch-studio" in (newest.display_name or ""):
            return newest.id
        return None

    def select_teamspace(self, teamspace: Optional[str], org: Optional[str], user: Optional[str]) -> Teamspace:
        """Select a teamspace, first waiting for onboarding to finish if it is still in progress.

        Already-onboarded users are delegated to the module-level ``select_teamspace``.
        Otherwise, once onboarding completes: if there is a single candidate teamspace
        (the user did not pick an org), that one is used; else the first candidate that
        belongs to an org is used.

        Raises:
            RuntimeError: If onboarding times out or no suitable teamspace is found.
        """
        if self.is_onboarded:
            return select_teamspace(teamspace, org, user)

        # Only reached while onboarding is still in progress.
        menu = _TeamspacesMenu()
        self._wait_user_onboarding()
        # Onboarding is complete, so any organization the user chose is reflected here.
        possible_teamspaces = menu._get_possible_teamspaces(self.user)
        if len(possible_teamspaces) == 1:
            # The user did not select an org.
            value = next(iter(possible_teamspaces.values()))
            return Teamspace(name=value["name"], org=value["org"], user=value["user"])

        for value in possible_teamspaces.values():
            # The user selected an org; its default teamspace is the onboarding teamspace.
            if value["org"]:
                return Teamspace(name=value["name"], org=value["org"], user=value["user"])
        raise RuntimeError("Unable to select teamspace. Visit lightning.ai")