lightning-sdk 2025.10.8__py3-none-any.whl → 2025.10.22__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (96)
  1. lightning_sdk/__init__.py +6 -3
  2. lightning_sdk/api/base_studio_api.py +13 -9
  3. lightning_sdk/api/cloud_account_api.py +0 -2
  4. lightning_sdk/api/license_api.py +26 -59
  5. lightning_sdk/api/studio_api.py +15 -2
  6. lightning_sdk/base_studio.py +30 -17
  7. lightning_sdk/cli/base_studio/list.py +1 -3
  8. lightning_sdk/cli/entrypoint.py +8 -34
  9. lightning_sdk/cli/studio/connect.py +42 -92
  10. lightning_sdk/cli/studio/create.py +23 -1
  11. lightning_sdk/cli/studio/start.py +12 -2
  12. lightning_sdk/cli/utils/get_base_studio.py +24 -0
  13. lightning_sdk/cli/utils/handle_machine_and_gpus_args.py +71 -0
  14. lightning_sdk/cli/utils/logging.py +121 -0
  15. lightning_sdk/cli/utils/ssh_connection.py +1 -1
  16. lightning_sdk/constants.py +1 -0
  17. lightning_sdk/helpers.py +53 -34
  18. lightning_sdk/job/job.py +5 -0
  19. lightning_sdk/job/v1.py +8 -0
  20. lightning_sdk/job/v2.py +8 -0
  21. lightning_sdk/lightning_cloud/login.py +260 -10
  22. lightning_sdk/lightning_cloud/openapi/__init__.py +30 -3
  23. lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
  24. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +19 -19
  25. lightning_sdk/lightning_cloud/openapi/api/auth_service_api.py +97 -0
  26. lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +105 -0
  27. lightning_sdk/lightning_cloud/openapi/api/k8_s_cluster_service_api.py +1463 -240
  28. lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +108 -108
  29. lightning_sdk/lightning_cloud/openapi/api/sdk_command_history_service_api.py +141 -0
  30. lightning_sdk/lightning_cloud/openapi/models/__init__.py +29 -3
  31. lightning_sdk/lightning_cloud/openapi/models/cloudspace_id_visibility_body.py +27 -1
  32. lightning_sdk/lightning_cloud/openapi/models/cluster_id_metrics_body.py +53 -1
  33. lightning_sdk/lightning_cloud/openapi/models/create_machine_request_represents_the_request_to_create_a_machine.py +27 -1
  34. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  35. lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +79 -1
  36. lightning_sdk/lightning_cloud/openapi/models/id_fork_body1.py +27 -1
  37. lightning_sdk/lightning_cloud/openapi/models/id_transfer_body.py +53 -1
  38. lightning_sdk/lightning_cloud/openapi/models/incident_id_messages_body.py +149 -0
  39. lightning_sdk/lightning_cloud/openapi/models/incidents_id_body.py +279 -0
  40. lightning_sdk/lightning_cloud/openapi/models/license_key_validate_body.py +123 -0
  41. lightning_sdk/lightning_cloud/openapi/models/messages_message_id_body.py +149 -0
  42. lightning_sdk/lightning_cloud/openapi/models/project_id_incidents_body.py +279 -0
  43. lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +27 -1
  44. lightning_sdk/lightning_cloud/openapi/models/storage_complete_body.py +15 -15
  45. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_transfer_metadata.py +53 -1
  46. lightning_sdk/lightning_cloud/openapi/models/v1_create_license_request.py +175 -0
  47. lightning_sdk/lightning_cloud/openapi/models/v1_create_project_request.py +27 -1
  48. lightning_sdk/lightning_cloud/openapi/models/v1_create_sdk_command_history_request.py +253 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_create_sdk_command_history_response.py +97 -0
  50. lightning_sdk/lightning_cloud/openapi/models/v1_delete_incident_message_response.py +97 -0
  51. lightning_sdk/lightning_cloud/openapi/models/v1_delete_incident_response.py +97 -0
  52. lightning_sdk/lightning_cloud/openapi/models/v1_delete_license_response.py +97 -0
  53. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  54. lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
  55. lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_transfer_estimate_response.py +149 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_group_pod_metrics.py +1241 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_incident.py +565 -0
  58. lightning_sdk/lightning_cloud/openapi/models/v1_incident_detail.py +149 -0
  59. lightning_sdk/lightning_cloud/openapi/models/v1_incident_event.py +27 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_incident_message.py +253 -0
  61. lightning_sdk/lightning_cloud/openapi/models/v1_incident_type.py +1 -0
  62. lightning_sdk/lightning_cloud/openapi/models/v1_job.py +53 -1
  63. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
  64. lightning_sdk/lightning_cloud/openapi/models/v1_kai_scheduler_queue_metrics.py +627 -0
  65. lightning_sdk/lightning_cloud/openapi/models/v1_license.py +227 -0
  66. lightning_sdk/lightning_cloud/openapi/models/v1_list_group_pod_metrics_response.py +123 -0
  67. lightning_sdk/lightning_cloud/openapi/models/v1_list_incident_messages_response.py +149 -0
  68. lightning_sdk/lightning_cloud/openapi/models/v1_list_incidents_response.py +149 -0
  69. lightning_sdk/lightning_cloud/openapi/models/v1_list_kai_scheduler_queues_metrics_response.py +123 -0
  70. lightning_sdk/lightning_cloud/openapi/models/{v1_list_product_licenses_response.py → v1_list_license_response.py} +16 -16
  71. lightning_sdk/lightning_cloud/openapi/models/v1_machine.py +79 -1
  72. lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +27 -1
  73. lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +27 -1
  74. lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +27 -1
  75. lightning_sdk/lightning_cloud/openapi/models/v1_resource_visibility.py +1 -27
  76. lightning_sdk/lightning_cloud/openapi/models/v1_sdk_command_history_severity.py +104 -0
  77. lightning_sdk/lightning_cloud/openapi/models/v1_sdk_command_history_type.py +104 -0
  78. lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +1 -0
  79. lightning_sdk/lightning_cloud/openapi/models/v1_slack_notifier.py +53 -1
  80. lightning_sdk/lightning_cloud/openapi/models/v1_token_login_request.py +123 -0
  81. lightning_sdk/lightning_cloud/openapi/models/v1_token_login_response.py +123 -0
  82. lightning_sdk/lightning_cloud/openapi/models/v1_token_owner_type.py +104 -0
  83. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +209 -131
  84. lightning_sdk/lightning_cloud/openapi/models/{v1_product_license_check_response.py → v1_validate_license_response.py} +21 -21
  85. lightning_sdk/lightning_cloud/rest_client.py +48 -45
  86. lightning_sdk/machine.py +2 -1
  87. lightning_sdk/studio.py +22 -2
  88. lightning_sdk/utils/license.py +13 -0
  89. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/METADATA +1 -1
  90. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/RECORD +94 -64
  91. lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +0 -435
  92. lightning_sdk/services/license.py +0 -363
  93. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/LICENSE +0 -0
  94. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/WHEEL +0 -0
  95. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/entry_points.txt +0 -0
  96. {lightning_sdk-2025.10.8.dist-info → lightning_sdk-2025.10.22.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -2,7 +2,7 @@ from lightning_sdk.agents import Agent
2
2
  from lightning_sdk.ai_hub import AIHub
3
3
  from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__ # noqa: F401
4
4
  from lightning_sdk.deployment import Deployment
5
- from lightning_sdk.helpers import _check_version_and_prompt_upgrade, _set_tqdm_envvars_noninteractive
5
+ from lightning_sdk.helpers import VersionChecker, _set_tqdm_envvars_noninteractive
6
6
  from lightning_sdk.job import Job
7
7
  from lightning_sdk.machine import CloudProvider, Machine
8
8
  from lightning_sdk.mmt import MMT
@@ -35,6 +35,9 @@ __all__ = [
35
35
  "VM",
36
36
  ]
37
37
 
38
- __version__ = "2025.10.08"
39
- _check_version_and_prompt_upgrade(__version__)
38
+ __version__ = "2025.10.22"
39
+
40
+ _version_checker = VersionChecker()
41
+ _version_checker.check_and_prompt_upgrade(__version__)
42
+
40
43
  _set_tqdm_envvars_noninteractive()
lightning_sdk/api/base_studio_api.py CHANGED
@@ -15,24 +15,28 @@ class BaseStudioApi:
15
15
  def __init__(self) -> None:
16
16
  self._client = LightningClient(retry=False, max_tries=0)
17
17
 
18
- def get_base_studio(self, base_studio_id: str, org_id: str) -> V1CloudSpaceEnvironmentTemplate:
18
+ def get_base_studio(self, base_studio_id: str, org_id: Optional[str] = None) -> V1CloudSpaceEnvironmentTemplate:
19
19
  """Retrieve the base studio by its ID."""
20
20
  try:
21
21
  return self._client.cloud_space_environment_template_service_get_cloud_space_environment_template(
22
- base_studio_id, org_id=org_id
22
+ base_studio_id, org_id=org_id or ""
23
23
  )
24
24
  except ValueError as e:
25
25
  raise ValueError(f"Base studio {base_studio_id} does not exist") from e
26
26
 
27
- def get_all_base_studios(self, org_id: str, managed: bool = True) -> V1ListCloudSpaceEnvironmentTemplatesResponse:
27
+ def get_all_base_studios(self, org_id: Optional[str]) -> V1ListCloudSpaceEnvironmentTemplatesResponse:
28
28
  """Retrieve all base studios for a given organization."""
29
- if managed:
30
- return self._client.cloud_space_environment_template_service_list_managed_cloud_space_environment_templates(
31
- org_id=org_id
32
- )
33
- return self._client.cloud_space_environment_template_service_list_cloud_space_environment_templates(
34
- org_id=org_id
29
+ result = self._client.cloud_space_environment_template_service_list_managed_cloud_space_environment_templates(
30
+ org_id=org_id or ""
35
31
  )
32
+ if org_id is not None:
33
+ org_templates = (
34
+ self._client.cloud_space_environment_template_service_list_cloud_space_environment_templates(
35
+ org_id=org_id
36
+ )
37
+ )
38
+ result.templates = result.templates + org_templates.templates
39
+ return result
36
40
 
37
41
  def update_base_studio(
38
42
  self,
lightning_sdk/api/cloud_account_api.py CHANGED
@@ -173,8 +173,6 @@ class CloudAccountApi:
173
173
  return CloudProvider.GCP
174
174
  if cloud_account.spec.lambda_labs_v1:
175
175
  return CloudProvider.LAMBDA_LABS
176
- if cloud_account.spec.vultr_v1:
177
- return CloudProvider.VULTR
178
176
  if cloud_account.spec.voltage_park_v1:
179
177
  return CloudProvider.VOLTAGE_PARK
180
178
  if cloud_account.spec.nebius_v1:
lightning_sdk/api/license_api.py CHANGED
@@ -1,70 +1,37 @@
1
- import os
2
- from typing import Optional
3
- from urllib.parse import urlencode
4
-
5
- from lightning_sdk.lightning_cloud import env
6
- from lightning_sdk.lightning_cloud.rest_client import LightningClient
7
-
8
- LICENSE_CODE = os.environ.get("LICENSE_CODE", "d9s79g79ss")
9
- # https://lightning.ai/home?settings=licenses
10
- LICENSE_SIGNING_URL = f"{env.LIGHTNING_CLOUD_URL}?settings=licenses"
11
-
12
-
13
- def generate_url_user_settings(name: str, redirect_to: str = LICENSE_SIGNING_URL) -> str:
14
- params = urlencode({"redirectTo": redirect_to, "okbhrt": LICENSE_CODE, "licenseName": name})
15
- return f"{env.LIGHTNING_CLOUD_URL}/sign-in?{params}"
1
+ from lightning_sdk.api.utils import _get_cloud_url as _cloud_url
2
+ from lightning_sdk.lightning_cloud.login import Auth
3
+ from lightning_sdk.lightning_cloud.openapi import LicenseKeyValidateBody, ProductLicenseServiceApi
16
4
 
17
5
 
18
6
  class LicenseApi:
19
- _client_authenticated: LightningClient = None
20
- _client_public: LightningClient = None
21
-
22
- @property
23
- def client_public(self) -> LightningClient:
24
- if not self._client_public:
25
- self._client_public = LightningClient(retry=False, max_tries=0, with_auth=False)
26
- return self._client_public
27
-
28
- @property
29
- def client_authenticated(self) -> LightningClient:
30
- if not self._client_authenticated:
31
- self._client_authenticated = LightningClient(retry=True, max_tries=3, with_auth=True)
32
- return self._client_authenticated
7
+ def __init__(self, login_token: str) -> None:
8
+ self._cloud_url = _cloud_url()
9
+ self._auth = Auth()
10
+ self._auth.token_login(login_token, save_token=True)
11
+ self._client = self._auth.create_api_client()
12
+ self._api = ProductLicenseServiceApi(self._client)
33
13
 
34
- def valid_license(
35
- self,
36
- license_key: str,
37
- product_name: str,
38
- product_version: Optional[str] = None,
39
- product_type: str = "package",
40
- ) -> bool:
41
- """Check if the license key is valid.
14
+ def validate_license(self, license_key: str, product_id: str) -> bool:
15
+ """Validate a license key for a specific product.
42
16
 
43
17
  Args:
44
- license_key: The license key to check.
45
- product_name: The name of the product.
46
- product_version: The version of the product.
47
- product_type: The type of the product. Default is "package".
18
+ license_key: The license key to validate
19
+ product_id: The product ID
48
20
 
49
21
  Returns:
50
- True if the license key is valid, False otherwise.
51
- """
52
- response = self.client_public.product_license_service_validate_product_license(
53
- license_key=license_key,
54
- product_name=product_name,
55
- product_version=product_version,
56
- product_type=product_type,
57
- )
58
- return response.valid
22
+ bool: True if license is valid, False otherwise
59
23
 
60
- def list_user_licenses(self, user_id: str) -> list:
61
- """List all licenses for a user.
24
+ Raises:
25
+ Exception: If license validation fails
26
+ """
27
+ try:
28
+ response = self._api.product_license_service_validate_license(
29
+ body=LicenseKeyValidateBody(product_id=product_id), license_key=license_key
30
+ )
31
+ return response.is_valid
32
+ except Exception:
33
+ raise InvalidLicenseError(f"Invalid license key {license_key} for product {product_id}") from None
62
34
 
63
- Args:
64
- user_id: The ID of the user.
65
35
 
66
- Returns:
67
- A list of licenses for the user.
68
- """
69
- response = self.client_authenticated.product_license_service_list_user_licenses(user_id=user_id)
70
- return response.licenses
36
+ class InvalidLicenseError(Exception):
37
+ pass
lightning_sdk/api/studio_api.py CHANGED
@@ -438,6 +438,14 @@ class StudioApi:
438
438
 
439
439
  return response.compute_config.spot
440
440
 
441
+ def get_public_ip(self, studio_id: str, teamspace_id: str) -> Optional[str]:
442
+ """Get the public IP address of the Studio."""
443
+ internal_status = self.get_studio_status(studio_id=studio_id, teamspace_id=teamspace_id).in_use
444
+ if internal_status is None:
445
+ return None
446
+
447
+ return internal_status.public_ip_address
448
+
441
449
  def _get_machines_for_cloud_account(
442
450
  self, teamspace_id: str, cloud_account_id: str, org_id: str
443
451
  ) -> List[V1ClusterAccelerator]:
@@ -581,7 +589,12 @@ class StudioApi:
581
589
  )
582
590
 
583
591
  def duplicate_studio(
584
- self, studio_id: str, teamspace_id: str, target_teamspace_id: str, machine: Machine = Machine.CPU
592
+ self,
593
+ studio_id: str,
594
+ teamspace_id: str,
595
+ target_teamspace_id: str,
596
+ machine: Machine = Machine.CPU,
597
+ new_name: Optional[str] = None,
585
598
  ) -> Dict[str, Any]:
586
599
  """Duplicates the given Studio from a given Teamspace into a given target Teamspace."""
587
600
  target_teamspace = self._client.projects_service_get_project(target_teamspace_id)
@@ -596,7 +609,7 @@ class StudioApi:
596
609
  init_kwargs["org"] = OrgApi()._get_org_by_id(target_teamspace.owner_id).name
597
610
 
598
611
  new_cloudspace = self._client.cloud_space_service_fork_cloud_space(
599
- IdForkBody1(target_project_id=target_teamspace_id), project_id=teamspace_id, id=studio_id
612
+ IdForkBody1(target_project_id=target_teamspace_id, new_name=new_name), project_id=teamspace_id, id=studio_id
600
613
  )
601
614
 
602
615
  while self.get_studio_by_id(new_cloudspace.id, target_teamspace_id).state != V1CloudSpaceState.READY:
lightning_sdk/base_studio.py CHANGED
@@ -3,11 +3,11 @@ from typing import List, Optional, Union
3
3
 
4
4
  from lightning_sdk.api.base_studio_api import BaseStudioApi
5
5
  from lightning_sdk.api.user_api import UserApi
6
- from lightning_sdk.lightning_cloud import login
7
6
  from lightning_sdk.lightning_cloud.openapi.models.v1_cloud_space_environment_type import V1CloudSpaceEnvironmentType
8
7
  from lightning_sdk.organization import Organization
8
+ from lightning_sdk.teamspace import Teamspace
9
9
  from lightning_sdk.user import User
10
- from lightning_sdk.utils.resolve import _resolve_org, _resolve_user
10
+ from lightning_sdk.utils.resolve import _resolve_teamspace
11
11
 
12
12
 
13
13
  @dataclass
@@ -24,6 +24,7 @@ class BaseStudio:
24
24
  def __init__(
25
25
  self,
26
26
  name: Optional[str] = None,
27
+ teamspace: Optional[Union[str, Teamspace]] = None,
27
28
  org: Optional[Union[str, Organization]] = None,
28
29
  user: Optional[Union[str, User]] = None,
29
30
  ) -> None:
@@ -38,26 +39,35 @@ class BaseStudio:
38
39
  Raises:
39
40
  ConnectionError: If there is an issue with the authentication process.
40
41
  """
41
- self._auth = login.Auth()
42
- self._user = None
42
+ self._teamspace = None
43
43
 
44
- try:
45
- self._auth.authenticate()
46
- if user is None:
47
- self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
48
- except ConnectionError as e:
49
- raise e
44
+ _teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
45
+ if _teamspace is None:
46
+ raise ValueError("Couldn't resolve teamspace from the provided name, org, or user")
50
47
 
51
- self._user = _resolve_user(self._user or user)
52
- self._org = _resolve_org(org)
48
+ self._teamspace = _teamspace
49
+
50
+ # self._auth = login.Auth()
51
+ # self._user = None
52
+
53
+ # try:
54
+ # self._auth.authenticate()
55
+ # if user is None:
56
+ # self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
57
+ # except ConnectionError as e:
58
+ # raise e
59
+
60
+ # self._user = _resolve_user(self._user or user)
61
+ # self._org = _resolve_org(org)
53
62
 
54
63
  self._base_studio_api = BaseStudioApi()
55
64
 
56
65
  if name is not None:
57
- base_studio = self._base_studio_api.get_base_studio(name, self._org.id)
66
+ org_id = self._teamspace._org.id if self._teamspace._org is not None else None
67
+ base_studio = self._base_studio_api.get_base_studio(name, org_id)
58
68
 
59
69
  if base_studio is None:
60
- raise ValueError(f"Base studio with name {name} does not exist in organization {self._org.name}")
70
+ raise ValueError(f"Base studio with name {name} does not exist")
61
71
  self._base_studio = base_studio
62
72
 
63
73
  def update(
@@ -70,9 +80,11 @@ class BaseStudio:
70
80
  machine_image_version: Optional[str] = None,
71
81
  setup_script_text: Optional[str] = None,
72
82
  ) -> None:
83
+ org_id = self._teamspace._org.id if self._teamspace._org is not None else None
84
+ # TODO: if not in an org, can't update them
73
85
  self._base_studio = self._base_studio_api.update_base_studio(
74
86
  self._base_studio.id,
75
- self._org.id,
87
+ org_id,
76
88
  name=name,
77
89
  allowed_machines=allowed_machines,
78
90
  default_machine=default_machine,
@@ -82,7 +94,7 @@ class BaseStudio:
82
94
  disabled=disabled,
83
95
  )
84
96
 
85
- def list(self, managed: bool = True, include_disabled: bool = False) -> List[BaseStudioInfo]:
97
+ def list(self, include_disabled: bool = False) -> List[BaseStudioInfo]:
86
98
  """List all base studios in the organization.
87
99
 
88
100
  Args:
@@ -92,7 +104,8 @@ class BaseStudio:
92
104
  Returns:
93
105
  List[BaseStudioInfo]: A list of base studio templates.
94
106
  """
95
- templates = self._base_studio_api.get_all_base_studios(self._org.id, managed).templates
107
+ org_id = self._teamspace._org.id if self._teamspace._org is not None else None
108
+ templates = self._base_studio_api.get_all_base_studios(org_id).templates
96
109
 
97
110
  return [
98
111
  BaseStudioInfo(
lightning_sdk/cli/base_studio/list.py CHANGED
@@ -21,9 +21,7 @@ def list_base_studios(include_disabled: bool) -> None:
21
21
 
22
22
  def list_impl(include_disabled: bool) -> None:
23
23
  base_studio_cls = BaseStudio()
24
- base_studios = base_studio_cls.list(include_disabled=include_disabled) + base_studio_cls.list(
25
- managed=False, include_disabled=include_disabled
26
- )
24
+ base_studios = base_studio_cls.list(include_disabled=include_disabled)
27
25
 
28
26
  table = Table(
29
27
  pad_edge=True,
lightning_sdk/cli/entrypoint.py CHANGED
@@ -2,15 +2,8 @@
2
2
 
3
3
  import os
4
4
  import sys
5
- import traceback
6
- from types import TracebackType
7
- from typing import Type
8
5
 
9
6
  import click
10
- from rich.console import Group
11
- from rich.panel import Panel
12
- from rich.syntax import Syntax
13
- from rich.text import Text
14
7
 
15
8
  from lightning_sdk import __version__
16
9
  from lightning_sdk.api.studio_api import _cloud_url
@@ -24,38 +17,19 @@ from lightning_sdk.cli.groups import (
24
17
  studio,
25
18
  vm,
26
19
  )
27
- from lightning_sdk.cli.utils import CustomHelpFormatter, rich_to_str
28
- from lightning_sdk.constants import _LIGHTNING_DEBUG
20
+ from lightning_sdk.cli.utils import CustomHelpFormatter
21
+ from lightning_sdk.cli.utils.logging import CommandLoggingGroup, logging_excepthook
29
22
  from lightning_sdk.lightning_cloud.login import Auth
30
23
 
31
24
 
32
- def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None:
33
- """CLI won't show tracebacks, just print the exception message."""
34
- message = str(value.args[0]) if value.args else str(value) or "An unknown error occurred"
35
-
36
- error_text = Text()
37
- error_text.append(f"{exception_type.__name__}: ", style="bold red")
38
- error_text.append(message, style="white")
39
-
40
- renderables = [error_text]
41
-
42
- if _LIGHTNING_DEBUG:
43
- tb_text = "".join(traceback.format_exception(exception_type, value, tb))
44
- renderables.append(Text("\n\nFull traceback:\n", style="bold yellow"))
45
- renderables.append(Syntax(tb_text, "python", theme="monokai light", line_numbers=False, word_wrap=True))
46
- else:
47
- renderables.append(Text("\n\n🐞 To view the full traceback, set: LIGHTNING_DEBUG=1"))
48
-
49
- renderables.append(Text("\n📘 Need help? Run: lightning <command> --help", style="cyan"))
50
-
51
- text = rich_to_str(Panel(Group(*renderables), title="⚡ Lightning CLI Error", border_style="red"))
52
- click.echo(text, color=True)
53
-
54
-
55
- @click.group(name="lightning", help="Command line interface (CLI) to interact with/manage Lightning AI Studios.")
25
+ @click.group(
26
+ name="lightning",
27
+ help="Command line interface (CLI) to interact with/manage Lightning AI Studios.",
28
+ cls=CommandLoggingGroup,
29
+ )
56
30
  @click.version_option(__version__, message="Lightning CLI version %(version)s")
57
31
  def main_cli() -> None:
58
- sys.excepthook = _notify_exception
32
+ sys.excepthook = logging_excepthook
59
33
 
60
34
 
61
35
  main_cli.context_class.formatter_class = CustomHelpFormatter
lightning_sdk/cli/studio/connect.py CHANGED
@@ -2,11 +2,13 @@
2
2
 
3
3
  import subprocess
4
4
  import sys
5
- from typing import Dict, Optional, Set
5
+ from contextlib import suppress
6
+ from typing import Optional
6
7
 
7
8
  import click
8
9
 
9
- from lightning_sdk.base_studio import BaseStudio
10
+ from lightning_sdk.cli.utils.get_base_studio import get_base_studio_id
11
+ from lightning_sdk.cli.utils.handle_machine_and_gpus_args import handle_machine_and_gpus_args
10
12
  from lightning_sdk.cli.utils.richt_print import studio_name_link
11
13
  from lightning_sdk.cli.utils.save_to_config import save_studio_to_config, save_teamspace_to_config
12
14
  from lightning_sdk.cli.utils.ssh_connection import configure_ssh_internal
@@ -16,81 +18,42 @@ from lightning_sdk.machine import CloudProvider, Machine
16
18
  from lightning_sdk.studio import Studio
17
19
  from lightning_sdk.utils.names import random_unique_name
18
20
 
19
- DEFAULT_MACHINE = "CPU"
20
21
 
22
+ def _parse_args_or_get_from_current_studio(
23
+ teamspace: Optional[str],
24
+ cloud_account: Optional[str],
25
+ studio_type: Optional[str],
26
+ machine: Optional[str],
27
+ gpus: Optional[str],
28
+ cloud_provider: Optional[str],
29
+ name: Optional[str],
30
+ ) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
31
+ # Parse args provided by user
32
+ menu = TeamspacesMenu()
33
+ resolved_teamspace = menu(teamspace)
34
+ save_teamspace_to_config(resolved_teamspace, overwrite=False)
21
35
 
22
- def _split_gpus_spec(gpus: str) -> tuple[str, int]:
23
- machine_name, machine_val = gpus.split(":", 1)
24
- machine_name = machine_name.strip()
25
- machine_val = machine_val.strip()
26
-
27
- if not machine_val.isdigit() or int(machine_val) <= 0:
28
- raise ValueError(f"Invalid GPU count '{machine_val}'. Must be a positive integer.")
29
-
30
- machine_num = int(machine_val)
31
- return machine_name, machine_num
32
-
33
-
34
- def _construct_available_gpus(machine_options: Dict[str, str]) -> Set[str]:
35
- # returns available gpus:count
36
- available_gpus = set()
37
- for v in machine_options.values():
38
- if "_X_" in v:
39
- gpu_type_num = v.replace("_X_", ":")
40
- available_gpus.add(gpu_type_num)
41
- else:
42
- available_gpus.add(v)
43
- return available_gpus
44
-
45
-
46
- def _get_machine_from_gpus(gpus: str) -> Machine:
47
- machine_name = gpus
48
- machine_num = 1
49
-
50
- if ":" in gpus:
51
- machine_name, machine_num = _split_gpus_spec(gpus)
52
-
53
- machine_options = {
54
- m.name.lower(): m.name for m in Machine.__dict__.values() if isinstance(m, Machine) and m._include_in_cli
55
- }
56
-
57
- if machine_num == 1:
58
- # e.g. gpus=L4 or gpus=L4:1
59
- gpu_key = machine_name.lower()
60
- try:
61
- return machine_options[gpu_key]
62
- except KeyError:
63
- available = ", ".join(_construct_available_gpus(machine_options))
64
- raise ValueError(f"Invalid GPU type '{machine_name}'. Available options: {available}") from None
65
-
66
- # Else: e.g. gpus=L4:4
67
- gpu_key = f"{machine_name.lower()}_x_{machine_num}"
68
- try:
69
- return machine_options[gpu_key]
70
- except KeyError:
71
- available = ", ".join(_construct_available_gpus(machine_options))
72
- raise ValueError(f"Invalid GPU configuration '{gpus}'. Available options: {available}") from None
36
+ template_id = get_base_studio_id(studio_type)
73
37
 
38
+ if cloud_provider is not None:
39
+ cloud_provider = CloudProvider(cloud_provider)
74
40
 
75
- def _get_base_studio_id(studio_type: Optional[str]) -> Optional[str]:
76
- base_studios = BaseStudio()
77
- base_studios = base_studios.list()
78
- template_id = None
41
+ name = name or random_unique_name()
79
42
 
80
- if base_studios and len(base_studios):
81
- # if not specified by user, use the first existing template studio
82
- template_id = base_studios[0].id
83
- # else, try to match the provided studio_type to base studio name
84
- if studio_type:
85
- normalized_studio_type = studio_type.lower().replace(" ", "-")
86
- match = next(
87
- (s for s in base_studios if s.name.lower().replace(" ", "-") == normalized_studio_type),
88
- None,
89
- )
90
- if match:
91
- template_id = match.id
43
+ with suppress(ValueError):
44
+ # Gets current studio context to use its parameters as defaults
45
+ s = Studio()
46
+ if not teamspace:
47
+ resolved_teamspace = s.teamspace
48
+ save_teamspace_to_config(resolved_teamspace, overwrite=False)
49
+ if not cloud_account:
50
+ cloud_account = s.cloud_account
51
+ if not template_id:
52
+ template_id = s._studio.environment_template_id
53
+ if not machine and not gpus:
54
+ machine = s.machine
92
55
 
93
- return template_id
56
+ return resolved_teamspace, cloud_account, template_id, machine, cloud_provider, name
94
57
 
95
58
 
96
59
  @click.command("connect")
@@ -124,6 +87,7 @@ def _get_base_studio_id(studio_type: Optional[str]) -> Optional[str]:
124
87
  "Defaults to the first available template.",
125
88
  type=click.STRING,
126
89
  )
90
+ @click.option("--interruptible", is_flag=True, help="Start the studio on an interruptible instance.")
127
91
  def connect_studio(
128
92
  name: Optional[str] = None,
129
93
  teamspace: Optional[str] = None,
@@ -132,29 +96,21 @@ def connect_studio(
132
96
  machine: Optional[str] = None,
133
97
  gpus: Optional[str] = None,
134
98
  studio_type: Optional[str] = None,
99
+ interruptible: bool = False,
135
100
  ) -> None:
136
101
  """Connect to a Studio.
137
102
 
138
103
  Example:
139
104
  lightning studio connect
140
105
  """
141
- menu = TeamspacesMenu()
142
-
143
- resolved_teamspace = menu(teamspace)
144
- save_teamspace_to_config(resolved_teamspace, overwrite=False)
145
-
146
- if cloud_provider is not None:
147
- cloud_provider = CloudProvider(cloud_provider)
148
-
149
- name = name or random_unique_name()
150
-
151
- # check for available base studios
152
- template_id = _get_base_studio_id(studio_type)
106
+ teamspace, cloud_account, template_id, machine, cloud_provider, name = _parse_args_or_get_from_current_studio(
107
+ teamspace, cloud_account, studio_type, machine, gpus, cloud_provider, name
108
+ )
153
109
 
154
110
  try:
155
111
  studio = Studio(
156
112
  name=name,
157
- teamspace=resolved_teamspace,
113
+ teamspace=teamspace,
158
114
  create_ok=True,
159
115
  cloud_provider=cloud_provider,
160
116
  cloud_account=cloud_account,
@@ -167,16 +123,10 @@ def connect_studio(
167
123
 
168
124
  Studio.show_progress = True
169
125
 
170
- if machine and gpus:
171
- raise click.UsageError("Options --machine and --gpu are mutually exclusive. Provide only one.")
172
- elif gpus:
173
- machine = _get_machine_from_gpus(gpus.strip())
174
- elif not machine:
175
- machine = DEFAULT_MACHINE
126
+ machine = handle_machine_and_gpus_args(machine, gpus)
176
127
 
177
128
  save_studio_to_config(studio)
178
- # by default, interruptible is False
179
- studio.start(machine=machine, interruptible=False)
129
+ studio.start(machine=machine, interruptible=interruptible)
180
130
 
181
131
  ssh_private_key_path = configure_ssh_internal()
182
132
 
lightning_sdk/cli/studio/create.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional
4
4
 
5
5
  import click
6
6
 
7
+ from lightning_sdk.cli.utils.get_base_studio import get_base_studio_id
7
8
  from lightning_sdk.cli.utils.richt_print import studio_name_link
8
9
  from lightning_sdk.cli.utils.save_to_config import save_teamspace_to_config
9
10
  from lightning_sdk.cli.utils.teamspace_selection import TeamspacesMenu
@@ -25,18 +26,34 @@ from lightning_sdk.studio import VM, Studio
25
26
  help="The cloud account to create the studio on. Defaults to teamspace default.",
26
27
  type=click.STRING,
27
28
  )
29
+ @click.option(
30
+ "--studio-type",
31
+ help="The base studio template name to use for creating the studio. "
32
+ "Must be lowercase and hyphenated (use '-' instead of spaces). "
33
+ "Run 'lightning base-studio list' to see all available templates. "
34
+ "Defaults to the first available template.",
35
+ type=click.STRING,
36
+ )
28
37
  def create_studio(
29
38
  name: Optional[str] = None,
30
39
  teamspace: Optional[str] = None,
31
40
  cloud_provider: Optional[str] = None,
32
41
  cloud_account: Optional[str] = None,
42
+ studio_type: Optional[str] = None,
33
43
  ) -> None:
34
44
  """Create a new Studio.
35
45
 
36
46
  Example:
37
47
  lightning studio create
38
48
  """
39
- create_impl(name=name, teamspace=teamspace, cloud_provider=cloud_provider, cloud_account=cloud_account, vm=False)
49
+ create_impl(
50
+ name=name,
51
+ teamspace=teamspace,
52
+ cloud_provider=cloud_provider,
53
+ cloud_account=cloud_account,
54
+ vm=False,
55
+ studio_type=studio_type,
56
+ )
40
57
 
41
58
 
42
59
  def create_impl(
@@ -45,6 +62,7 @@ def create_impl(
45
62
  cloud_provider: Optional[str],
46
63
  cloud_account: Optional[str],
47
64
  vm: bool,
65
+ studio_type: Optional[str],
48
66
  ) -> None:
49
67
  menu = TeamspacesMenu()
50
68
 
@@ -57,6 +75,9 @@ def create_impl(
57
75
  create_cls = VM if vm else Studio
58
76
  cls_name = create_cls.__qualname__
59
77
 
78
+ # check for available base studios
79
+ template_id = get_base_studio_id(studio_type)
80
+
60
81
  try:
61
82
  create_cls = VM if vm else Studio
62
83
  studio = create_cls(
@@ -65,6 +86,7 @@ def create_impl(
65
86
  create_ok=True,
66
87
  cloud_provider=cloud_provider,
67
88
  cloud_account=cloud_account,
89
+ template_id=template_id,
68
90
  )
69
91
  except (RuntimeError, ValueError, ApiException):
70
92
  if name: