lightning-sdk 2025.10.14__py3-none-any.whl → 2025.10.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. lightning_sdk/__init__.py +6 -3
  2. lightning_sdk/api/base_studio_api.py +13 -9
  3. lightning_sdk/api/license_api.py +26 -59
  4. lightning_sdk/api/studio_api.py +7 -2
  5. lightning_sdk/base_studio.py +30 -17
  6. lightning_sdk/cli/base_studio/list.py +1 -3
  7. lightning_sdk/cli/entrypoint.py +8 -34
  8. lightning_sdk/cli/studio/connect.py +42 -92
  9. lightning_sdk/cli/studio/create.py +23 -1
  10. lightning_sdk/cli/studio/start.py +12 -2
  11. lightning_sdk/cli/utils/get_base_studio.py +24 -0
  12. lightning_sdk/cli/utils/handle_machine_and_gpus_args.py +71 -0
  13. lightning_sdk/cli/utils/logging.py +121 -0
  14. lightning_sdk/cli/utils/ssh_connection.py +1 -1
  15. lightning_sdk/constants.py +1 -0
  16. lightning_sdk/helpers.py +53 -34
  17. lightning_sdk/lightning_cloud/login.py +260 -10
  18. lightning_sdk/lightning_cloud/openapi/__init__.py +10 -3
  19. lightning_sdk/lightning_cloud/openapi/api/auth_service_api.py +97 -0
  20. lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +108 -108
  21. lightning_sdk/lightning_cloud/openapi/models/__init__.py +10 -3
  22. lightning_sdk/lightning_cloud/openapi/models/create_machine_request_represents_the_request_to_create_a_machine.py +27 -1
  23. lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +27 -1
  24. lightning_sdk/lightning_cloud/openapi/models/id_fork_body1.py +27 -1
  25. lightning_sdk/lightning_cloud/openapi/models/license_key_validate_body.py +123 -0
  26. lightning_sdk/lightning_cloud/openapi/models/v1_create_license_request.py +175 -0
  27. lightning_sdk/lightning_cloud/openapi/models/v1_delete_license_response.py +97 -0
  28. lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
  29. lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_transfer_estimate_response.py +29 -3
  30. lightning_sdk/lightning_cloud/openapi/models/v1_incident.py +27 -1
  31. lightning_sdk/lightning_cloud/openapi/models/v1_incident_detail.py +149 -0
  32. lightning_sdk/lightning_cloud/openapi/models/v1_incident_event.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/v1_license.py +227 -0
  34. lightning_sdk/lightning_cloud/openapi/models/{v1_list_product_licenses_response.py → v1_list_license_response.py} +16 -16
  35. lightning_sdk/lightning_cloud/openapi/models/v1_machine.py +27 -1
  36. lightning_sdk/lightning_cloud/openapi/models/v1_slack_notifier.py +53 -1
  37. lightning_sdk/lightning_cloud/openapi/models/v1_token_login_request.py +123 -0
  38. lightning_sdk/lightning_cloud/openapi/models/v1_token_login_response.py +123 -0
  39. lightning_sdk/lightning_cloud/openapi/models/v1_token_owner_type.py +104 -0
  40. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +53 -79
  41. lightning_sdk/lightning_cloud/openapi/models/{v1_product_license_check_response.py → v1_validate_license_response.py} +21 -21
  42. lightning_sdk/lightning_cloud/rest_client.py +48 -45
  43. lightning_sdk/machine.py +2 -0
  44. lightning_sdk/studio.py +14 -2
  45. lightning_sdk/utils/license.py +13 -0
  46. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/METADATA +1 -1
  47. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/RECORD +51 -41
  48. lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +0 -435
  49. lightning_sdk/services/license.py +0 -363
  50. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/LICENSE +0 -0
  51. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/WHEEL +0 -0
  52. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/entry_points.txt +0 -0
  53. {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,24 @@
1
+ from typing import Optional
2
+
3
+ from lightning_sdk.base_studio import BaseStudio
4
+
5
+
6
def get_base_studio_id(studio_type: Optional[str]) -> Optional[str]:
    """Resolve the id of the base studio (template) to use.

    The available base studios are listed via ``BaseStudio().list()``. When
    ``studio_type`` is given, it is compared against each base studio name
    case-insensitively with spaces treated as hyphens; the first studio whose
    name matches wins. Without a ``studio_type`` (or without a match) the
    first listed base studio is used.

    Args:
        studio_type: Optional base studio name supplied by the user.

    Returns:
        The selected base studio id, or ``None`` when no base studios are listed.
    """
    available = BaseStudio().list()
    if not available:
        return None

    def _normalize(name: str) -> str:
        return name.lower().replace(" ", "-")

    # Default: the first existing template studio.
    chosen = available[0]

    if studio_type:
        wanted = _normalize(studio_type)
        for candidate in available:
            if _normalize(candidate.name) == wanted:
                # Only the first name match is considered; it replaces the
                # default only if the matched object is truthy.
                if candidate:
                    chosen = candidate
                break

    return chosen.id
@@ -0,0 +1,71 @@
1
+ from typing import Dict, Optional, Set
2
+
3
+ import click
4
+
5
+ from lightning_sdk.machine import Machine
6
+
7
+ DEFAULT_MACHINE = "CPU"
8
+
9
+
10
+ def _split_gpus_spec(gpus: str) -> tuple[str, int]:
11
+ machine_name, machine_val = gpus.split(":", 1)
12
+ machine_name = machine_name.strip()
13
+ machine_val = machine_val.strip()
14
+
15
+ if not machine_val.isdigit() or int(machine_val) <= 0:
16
+ raise ValueError(f"Invalid GPU count '{machine_val}'. Must be a positive integer.")
17
+
18
+ machine_num = int(machine_val)
19
+ return machine_name, machine_num
20
+
21
+
22
+ def _construct_available_gpus(machine_options: Dict[str, str]) -> Set[str]:
23
+ # returns available gpus:count
24
+ available_gpus = set()
25
+ for v in machine_options.values():
26
+ if "_X_" in v:
27
+ gpu_type_num = v.replace("_X_", ":")
28
+ available_gpus.add(gpu_type_num)
29
+ else:
30
+ available_gpus.add(v)
31
+ return available_gpus
32
+
33
+
34
def _get_machine_from_gpus(gpus: str) -> str:
    """Translate a ``--gpus`` value into the name of a CLI-visible machine.

    ``gpus`` is either a bare GPU type (``"L4"``) or ``TYPE:COUNT``
    (``"L4:4"``). A count of 1 is looked up under the bare lower-cased type;
    any other count is looked up as ``<type>_x_<count>``. Lookups are done
    against the lower-cased names of the ``Machine`` class attributes that are
    ``Machine`` instances with ``_include_in_cli`` set.

    Args:
        gpus: The GPU spec provided by the user.

    Returns:
        The matching machine's ``name`` (the values of the lookup table are
        names, not ``Machine`` objects).

    Raises:
        ValueError: If the count is invalid or no machine matches the spec.
            The message lists the available options in sorted order.
    """
    machine_name, machine_num = gpus, 1
    if ":" in gpus:
        machine_name, machine_num = _split_gpus_spec(gpus)

    machine_options = {
        m.name.lower(): m.name for m in Machine.__dict__.values() if isinstance(m, Machine) and m._include_in_cli
    }

    if machine_num == 1:
        # e.g. gpus=L4 or gpus=L4:1
        gpu_key = machine_name.lower()
        problem = f"Invalid GPU type '{machine_name}'"
    else:
        # e.g. gpus=L4:4
        gpu_key = f"{machine_name.lower()}_x_{machine_num}"
        problem = f"Invalid GPU configuration '{gpus}'"

    try:
        return machine_options[gpu_key]
    except KeyError:
        # Sorted so the error text is stable from run to run (set order is not).
        available = ", ".join(sorted(_construct_available_gpus(machine_options)))
        raise ValueError(f"{problem}. Available options: {available}") from None
61
+
62
+
63
def handle_machine_and_gpus_args(machine: Optional[str], gpus: Optional[str]) -> str:
    """Reconcile the ``--machine`` and ``--gpus`` options into one machine value.

    Args:
        machine: Value of ``--machine``, if given.
        gpus: Value of ``--gpus``, if given; surrounding whitespace is stripped
            before it is resolved.

    Returns:
        The machine resolved from ``gpus`` when ``gpus`` is provided,
        otherwise ``machine``, otherwise ``DEFAULT_MACHINE``.

    Raises:
        click.UsageError: If both options are provided.
    """
    if gpus:
        if machine:
            raise click.UsageError("Options --machine and --gpus are mutually exclusive. Provide only one.")
        return _get_machine_from_gpus(gpus.strip())

    return machine or DEFAULT_MACHINE
@@ -0,0 +1,121 @@
1
+ import shlex
2
+ import sys
3
+ import traceback
4
+ from contextlib import suppress
5
+ from time import time
6
+ from types import TracebackType
7
+ from typing import Optional, Type
8
+
9
+ import click
10
+ from rich.console import Group
11
+ from rich.panel import Panel
12
+ from rich.syntax import Syntax
13
+ from rich.text import Text
14
+
15
+ from lightning_sdk.cli.utils import rich_to_str
16
+ from lightning_sdk.constants import _LIGHTNING_DEBUG
17
+ from lightning_sdk.lightning_cloud.openapi.models.v1_create_sdk_command_history_request import (
18
+ V1CreateSDKCommandHistoryRequest,
19
+ )
20
+ from lightning_sdk.lightning_cloud.openapi.models.v1_sdk_command_history_severity import V1SDKCommandHistorySeverity
21
+ from lightning_sdk.lightning_cloud.openapi.models.v1_sdk_command_history_type import V1SDKCommandHistoryType
22
+ from lightning_sdk.lightning_cloud.rest_client import LightningClient
23
+
24
+
25
def _log_command(message: str = "", duration: int = 0, error: Optional[str] = None) -> None:
    """Best-effort upload of the current CLI invocation to the SDK command history.

    The recorded command is ``sys.argv`` re-quoted with ``shlex.quote``. Any
    exception raised while creating the client or sending the request is
    suppressed, so logging can never replace the error of the command itself.

    Args:
        message: Free-form description of the invocation; truncated to 1000 characters.
        duration: Command duration, passed through as given.
        error: Optional error text. When set, it is appended to ``message``
            and the severity becomes ``WARNING`` if it equals ``"0"`` and
            ``ERROR`` otherwise.
    """
    original_command = " ".join(shlex.quote(arg) for arg in sys.argv)

    severity = V1SDKCommandHistorySeverity.INFO
    if error:
        # NOTE(review): "0" is presumably the string form of a zero exit code
        # (a clean early exit) -- confirm against the callers.
        severity = V1SDKCommandHistorySeverity.WARNING if error == "0" else V1SDKCommandHistorySeverity.ERROR
        message = message + f" | Error: {error}"

    body = V1CreateSDKCommandHistoryRequest(
        command=original_command,
        duration=duration,
        message=message[:1000],  # limit characters
        project_id=None,
        severity=severity,
        type=V1SDKCommandHistoryType.CLI,
    )

    with suppress(Exception):
        # The client is created inside the suppressed block as well: this
        # function runs from a ``finally`` clause and from the excepthook,
        # where a failure here would otherwise hide the original exception.
        client = LightningClient(retry=False, max_tries=0)
        client.s_dk_command_history_service_create_sdk_command_history(body=body)
47
+
48
+
49
def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None:
    """Print a formatted error panel for an exception instead of a raw traceback.

    The panel shows the exception type name and its message. The full
    traceback is included only when ``_LIGHTNING_DEBUG`` is truthy; otherwise
    a hint on how to enable it is shown. The panel is rendered to a string
    with ``rich_to_str`` and written with ``click.echo``.

    Args:
        exception_type: Class of the exception.
        value: The exception instance.
        tb: Traceback associated with ``value``.
    """
    # Parenthesised so the generic fallback also applies when the first
    # argument renders as an empty string, not only when there are no args.
    message = (str(value.args[0]) if value.args else str(value)) or "An unknown error occurred"

    error_text = Text()
    error_text.append(f"{exception_type.__name__}: ", style="bold red")
    error_text.append(message, style="white")

    renderables = [error_text]

    if _LIGHTNING_DEBUG:
        tb_text = "".join(traceback.format_exception(exception_type, value, tb))
        renderables.append(Text("\n\nFull traceback:\n", style="bold yellow"))
        renderables.append(Syntax(tb_text, "python", theme="monokai light", line_numbers=False, word_wrap=True))
    else:
        renderables.append(Text("\n\n🐞 To view the full traceback, set: LIGHTNING_DEBUG=1"))

    renderables.append(Text("\n📘 Need help? Run: lightning <command> --help", style="cyan"))

    text = rich_to_str(Panel(Group(*renderables), title="⚡ Lightning CLI Error", border_style="red"))
    click.echo(text, color=True)
70
+
71
+
72
def logging_excepthook(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None:
    """Exception hook that records the failure, then shows it to the user.

    The exception (type, value, formatted traceback) and the current click
    command path are sent through ``_log_command``. The user-facing panel is
    printed from a ``finally`` clause, so it is shown even if recording fails.

    Args:
        exception_type: Class of the uncaught exception.
        value: The uncaught exception instance.
        tb: Its traceback.
    """
    try:
        active_ctx = click.get_current_context(silent=True)
        if active_ctx:
            command_context = active_ctx.command_path
        else:
            command_context = "outside_command_context"

        tb_str = "".join(traceback.format_exception(exception_type, value, tb))

        _log_command(
            message=(
                f"Command: {command_context} | Type: {exception_type.__name__!s} | Value: {value!s} | Traceback: {tb_str}"
            )
        )
    finally:
        _notify_exception(exception_type, value, tb)
84
+
85
+
86
class CommandLoggingGroup(click.Group):
    """A ``click.Group`` that reports every invocation through ``_log_command``."""

    def _format_ctx(self, ctx: click.Context) -> str:
        """Render the command path, subcommand, params and extra args of ``ctx`` as one line.

        Parameters whose value is ``True`` are rendered as bare flags, ``False``
        and ``None`` values are left out, everything else becomes ``--name value``.
        """
        rendered = []
        for name, val in ctx.params.items():
            if val is None or val is False:
                continue
            rendered.append(f"--{name}" if val is True else f"--{name} {val}")

        params = " ".join(rendered)
        args = " ".join(ctx.args)
        return f"Commands: {ctx.command_path} | Subcommand: {ctx.invoked_subcommand} | Params: {params} | Args:{args}"

    # NOTE(review): ``any`` below is the builtin function, not ``typing.Any``;
    # kept as-is because ``Any`` is not imported in this module.
    def invoke(self, ctx: click.Context) -> any:
        """Invoke the group and log the outcome, whatever it is.

        ``click.ClickException`` is shown to the user and turned into
        ``ctx.exit(exit_code)``; any other ``Exception`` is re-raised after its
        text has been captured. The ``finally`` clause always logs the
        formatted context, the elapsed whole seconds and the captured error text.
        """
        started_at = time()
        failure = None

        try:
            return super().invoke(ctx)
        except click.ClickException as exc:
            failure = str(exc)
            exc.show()
            ctx.exit(exc.exit_code)
        except Exception as exc:
            failure = str(exc)
            raise
        finally:
            summary = self._format_ctx(ctx)
            elapsed = int(time() - started_at)
            _log_command(message=summary, duration=elapsed, error=failure)
@@ -16,7 +16,7 @@ def configure_ssh_internal(force_download: bool = False) -> str:
16
16
 
17
17
 
18
18
  def download_ssh_keys(
19
- api_key: str | None,
19
+ api_key: Optional[str],
20
20
  force_download: bool = False,
21
21
  ssh_key_name: str = "lightning_rsa",
22
22
  ) -> str:
@@ -29,3 +29,4 @@ class Store:
29
29
 
30
30
 
31
31
  __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__ = Store()
32
+ _LIGHTNING_DISABLE_VERSION_CHECK = int(os.getenv("LIGHTNING_DISABLE_VERSION_CHECK", "0"))
lightning_sdk/helpers.py CHANGED
@@ -1,4 +1,3 @@
1
- import functools
2
1
  import importlib
3
2
  import os
4
3
  import sys
@@ -10,48 +9,68 @@ import tqdm
10
9
  import tqdm.std
11
10
  from packaging import version as packaging_version
12
11
 
13
- __package_name__ = "lightning-sdk"
12
+ from lightning_sdk.constants import _LIGHTNING_DISABLE_VERSION_CHECK
14
13
 
15
14
 
16
- @functools.lru_cache(maxsize=1)
17
- def _get_newer_version(curr_version: str) -> Optional[str]:
18
- """Check PyPI for newer versions of ``lightning-sdk``.
19
-
20
- Returning the newest version if different from the current or ``None`` otherwise.
15
+ class VersionChecker:
16
+ """Handles version checking and upgrade prompts for lightning-sdk.
21
17
 
18
+ This class ensures that version check warnings are only shown once per session,
19
+ preventing duplicate warnings in multithreaded scenarios.
22
20
  """
23
- if packaging_version.parse(curr_version).is_prerelease:
24
- return None
25
- try:
26
- response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json")
27
- response_json = response.json()
28
- releases = response_json["releases"]
29
- if curr_version not in releases:
30
- # Always return None if not installed from PyPI (e.g. dev versions)
21
+
22
+ def __init__(self, package_name: str = "lightning-sdk") -> None:
23
+ self.package_name = package_name
24
+ self._warning_shown = False
25
+ self._cached_version: Optional[str] = None
26
+
27
+ def _get_newer_version(self, curr_version: str) -> Optional[str]:
28
+ """Check PyPI for newer versions of ``lightning-sdk``.
29
+
30
+ Returning the newest version if different from the current or ``None`` otherwise.
31
+ """
32
+ if self._cached_version is not None:
33
+ return self._cached_version
34
+
35
+ if _LIGHTNING_DISABLE_VERSION_CHECK == 1 or packaging_version.parse(curr_version).is_prerelease:
36
+ self._cached_version = None
31
37
  return None
32
- latest_version = response_json["info"]["version"]
33
- parsed_version = packaging_version.parse(latest_version)
34
- is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
35
- return None if curr_version == latest_version or is_invalid else latest_version
36
- except requests.exceptions.RequestException:
37
- return None
38
38
 
39
+ try:
40
+ response = requests.get(f"https://pypi.org/pypi/{self.package_name}/json")
41
+ response_json = response.json()
42
+ releases = response_json["releases"]
43
+ if curr_version not in releases:
44
+ # Always return None if not installed from PyPI (e.g. dev versions)
45
+ self._cached_version = None
46
+ return None
47
+ latest_version = response_json["info"]["version"]
48
+ parsed_version = packaging_version.parse(latest_version)
49
+ is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
50
+ self._cached_version = None if curr_version == latest_version or is_invalid else latest_version
51
+ return self._cached_version
52
+ except requests.exceptions.RequestException:
53
+ self._cached_version = None
54
+ return None
39
55
 
40
- def _check_version_and_prompt_upgrade(curr_version: str) -> None:
41
- """Checks that the current version of ``lightning-sdk`` is the latest on PyPI.
56
+ def check_and_prompt_upgrade(self, curr_version: str) -> None:
57
+ """Checks that the current version of ``lightning-sdk`` is the latest on PyPI.
42
58
 
43
- If not, warn the user to upgrade ``lightning-sdk``.
59
+ If not, warn the user to upgrade ``lightning-sdk``.
60
+ Tracks if the warning has already been shown in this session to avoid duplicate warnings.
61
+ """
62
+ if self._warning_shown:
63
+ return
44
64
 
45
- """
46
- new_version = _get_newer_version(curr_version)
47
- if new_version:
48
- warnings.warn(
49
- f"A newer version of {__package_name__} is available ({new_version}). "
50
- f"Please consider upgrading with `pip install -U {__package_name__}`. "
51
- "Not all platform functionality can be guaranteed to work with the current version.",
52
- UserWarning,
53
- )
54
- return
65
+ new_version = self._get_newer_version(curr_version)
66
+ if new_version:
67
+ warnings.warn(
68
+ f"A newer version of {self.package_name} is available ({new_version}). "
69
+ f"Please consider upgrading with `pip install -U {self.package_name}`. "
70
+ "Not all platform functionality can be guaranteed to work with the current version.",
71
+ UserWarning,
72
+ )
73
+ self._warning_shown = True
55
74
 
56
75
 
57
76
  def _set_tqdm_envvars_noninteractive() -> None:
@@ -5,7 +5,7 @@ import logging
5
5
  import os
6
6
  import pathlib
7
7
  from dataclasses import dataclass
8
- from typing import Optional
8
+ from typing import Optional, Literal
9
9
  from urllib.parse import urlencode
10
10
 
11
11
  import webbrowser
@@ -22,14 +22,21 @@ from lightning_sdk.lightning_cloud.openapi.api import \
22
22
  AuthServiceApi
23
23
  from lightning_sdk.lightning_cloud.openapi.models.v1_guest_login_request import \
24
24
  V1GuestLoginRequest
25
+ from lightning_sdk.lightning_cloud.openapi.models.v1_token_login_request import \
26
+ V1TokenLoginRequest
27
+ from lightning_sdk.lightning_cloud.openapi.models.v1_refresh_request import \
28
+ V1RefreshRequest
25
29
 
26
30
  logger = logging.getLogger(__name__)
27
31
 
32
+ # Authentication override types
33
+ AuthOverride = Literal["auth_token", "api_key", "guest"]
28
34
 
29
35
  class Keys(Enum):
30
36
  # USERNAME = "LIGHTNING_USERNAME"
31
37
  USER_ID = "LIGHTNING_USER_ID"
32
38
  API_KEY = "LIGHTNING_API_KEY"
39
+ AUTH_TOKEN = "LIGHTNING_AUTH_TOKEN"
33
40
 
34
41
  @property
35
42
  def suffix(self):
@@ -41,15 +48,18 @@ class Auth:
41
48
  # username: Optional[str] = None
42
49
  user_id: Optional[str] = None
43
50
  api_key: Optional[str] = None
51
+ auth_token: Optional[str] = None
44
52
 
45
53
  secrets_file = pathlib.Path(env.LIGHTNING_CREDENTIAL_PATH)
46
54
 
47
55
  def __post_init__(self):
48
56
  for key in Keys:
49
- setattr(self, key.suffix, os.environ.get(key.value, None))
57
+ # Only set from environment if not already set
58
+ if getattr(self, key.suffix) is None:
59
+ setattr(self, key.suffix, os.environ.get(key.value, None))
50
60
 
51
61
  # used by authenticate method
52
- self._with_env_var = bool(self.api_key)
62
+ self._with_env_var = bool(self.api_key or self.auth_token)
53
63
 
54
64
  def load(self) -> bool:
55
65
  """Load credentials from disk and update properties with credentials.
@@ -71,6 +81,7 @@ class Auth:
71
81
  token: str = "",
72
82
  user_id: str = "",
73
83
  api_key: str = "",
84
+ auth_token: str = "",
74
85
  # username: str = "",
75
86
  ) -> None:
76
87
  """save credentials to disk."""
@@ -81,6 +92,7 @@ class Auth:
81
92
  # f"{Keys.USERNAME.suffix}": username,
82
93
  f"{Keys.USER_ID.suffix}": user_id,
83
94
  f"{Keys.API_KEY.suffix}": api_key,
95
+ f"{Keys.AUTH_TOKEN.suffix}": auth_token,
84
96
  },
85
97
  f,
86
98
  )
@@ -88,6 +100,7 @@ class Auth:
88
100
  # self.username = username
89
101
  self.user_id = user_id
90
102
  self.api_key = api_key
103
+ self.auth_token = auth_token
91
104
  logger.debug("credentials saved successfully")
92
105
 
93
106
  @classmethod
@@ -99,15 +112,78 @@ class Auth:
99
112
  os.environ.pop(key.value, None)
100
113
  logger.debug("credentials removed successfully")
101
114
 
102
- @property
103
- def auth_header(self) -> Optional[str]:
104
- """authentication header used by lightning-cloud client."""
105
- if self.api_key:
115
+ def get_auth_header(self, override: Optional[AuthOverride] = None) -> Optional[str]:
116
+ """Get authentication header with optional override.
117
+
118
+ By default, uses the automatic priority selection (auth_token > api_key).
119
+ You can override this for specific cases where you need a different auth method.
120
+
121
+ Parameters
122
+ ----------
123
+ override : AuthOverride, optional
124
+ Override the default authentication method:
125
+ - "auth_token": Force use of JWT auth token (Bearer)
126
+ - "api_key": Force use of API key (Basic)
127
+ - "guest": Force use of guest credentials (Basic)
128
+ - None: Use automatic selection (default)
129
+
130
+ Returns
131
+ -------
132
+ Optional[str]
133
+ The authorization header for the specified method.
134
+
135
+ Raises
136
+ ------
137
+ ValueError
138
+ If the specified override is not available or invalid.
139
+ """
140
+ if override == "auth_token":
141
+ if not self.auth_token:
142
+ raise ValueError(
143
+ "Auth token override requested but no JWT token available. "
144
+ "Please use token_login() method first."
145
+ )
146
+ return f"Bearer {self.auth_token}"
147
+
148
+ elif override == "api_key":
149
+ if not self.api_key or not self.user_id:
150
+ raise ValueError(
151
+ "API key override requested but no API key or user ID available. "
152
+ "Please set LIGHTNING_API_KEY and LIGHTNING_USER_ID environment variables "
153
+ "or use authenticate() method."
154
+ )
155
+ token = f"{self.user_id}:{self.api_key}"
156
+ return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}" # noqa E501
157
+
158
+ elif override == "guest":
159
+ if not self.api_key or not self.user_id:
160
+ raise ValueError(
161
+ "Guest override requested but no guest credentials available. "
162
+ "Please call guest_login() method first."
163
+ )
106
164
  token = f"{self.user_id}:{self.api_key}"
107
165
  return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}" # noqa E501
108
- raise AttributeError(
109
- "Authentication Failed, no authentication header available. "
110
- "This is most likely a bug in the LightningCloud Framework")
166
+
167
+ elif override is None:
168
+ # Use the original automatic selection logic (default behavior)
169
+ if self.auth_token:
170
+ return f"Bearer {self.auth_token}"
171
+ elif self.api_key:
172
+ token = f"{self.user_id}:{self.api_key}"
173
+ return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}" # noqa E501
174
+ else:
175
+ raise ValueError(
176
+ "No authentication credentials available. Please authenticate first using "
177
+ "token_login(), guest_login(), or authenticate() methods."
178
+ )
179
+
180
+ else:
181
+ raise ValueError(f"Invalid authentication override: {override}")
182
+
183
+ @property
184
+ def auth_header(self) -> Optional[str]:
185
+ """authentication header used by lightning-cloud client (automatic selection)."""
186
+ return self.get_auth_header()
111
187
 
112
188
  def _run_server(self) -> None:
113
189
  """start a server to complete authentication."""
@@ -130,6 +206,8 @@ class Auth:
130
206
  self._run_server()
131
207
  return self.auth_header
132
208
 
209
+ elif self.auth_token:
210
+ return self.auth_header
133
211
  elif self.user_id and self.api_key:
134
212
  return self.auth_header
135
213
 
@@ -192,6 +270,178 @@ class Auth:
192
270
 
193
271
  return self.auth_header
194
272
 
273
+ def token_login(self, token_key: str, save_token: bool = True) -> Optional[str]:
274
+ """Performs token-based authentication.
275
+
276
+ This method sends a request to the token login endpoint to authenticate
277
+ using an auth token key, optionally saves the returned JWT token, and
278
+ returns the authorization header.
279
+
280
+ Parameters
281
+ ----------
282
+ token_key : str
283
+ The auth token key to use for authentication.
284
+ save_token : bool, optional
285
+ Whether to save the JWT token for future use. Defaults to True.
286
+ If False, the token is only used for the current session.
287
+
288
+ Returns
289
+ -------
290
+ Optional[str]
291
+ The authorization header to use for subsequent requests.
292
+
293
+ Raises
294
+ ------
295
+ RuntimeError
296
+ If the token login request fails.
297
+ ValueError
298
+ If the response from the server is invalid.
299
+ """
300
+ config = Configuration()
301
+ config.host = env.LIGHTNING_CLOUD_URL
302
+ api_client = ApiClient(configuration=config)
303
+ auth_api = AuthServiceApi(api_client)
304
+
305
+ logger.debug(f"Attempting token login to {config.host}")
306
+
307
+ try:
308
+ body = V1TokenLoginRequest(token_key=token_key)
309
+ response = auth_api.auth_service_token_login(body)
310
+
311
+ except requests.RequestException as e:
312
+ logger.error(f"Token login request failed: {e}")
313
+ raise RuntimeError(
314
+ "Failed to connect to the token login endpoint. "
315
+ "Please check your network connection and the server status."
316
+ ) from e
317
+
318
+ # Extract the JWT token from the response
319
+ jwt_token = getattr(response, "token", None)
320
+
321
+ if not jwt_token:
322
+ logger.error(
323
+ f"No token received from token login response: {response}"
324
+ )
325
+ raise ValueError(
326
+ "The token login response did not contain a valid JWT token."
327
+ )
328
+
329
+ # Set the JWT token in memory
330
+ self.auth_token = jwt_token
331
+
332
+ # Optionally save the JWT token to disk
333
+ if save_token:
334
+ self.save(auth_token=jwt_token)
335
+ logger.info("Successfully authenticated using auth token and saved to disk.")
336
+ else:
337
+ logger.info("Successfully authenticated using auth token (not saved to disk).")
338
+
339
+ return self.auth_header
340
+
341
+ def refresh_token(self, duration: int = 43200) -> Optional[str]:
342
+ """Refreshes the current JWT token.
343
+
344
+ This method sends a request to the refresh endpoint to get a new JWT token
345
+ with the specified duration, saves the new token, and returns the updated
346
+ authorization header.
347
+
348
+ Parameters
349
+ ----------
350
+ duration : int, optional
351
+ Duration in seconds for the new token. Can range from 900 seconds (15 minutes)
352
+ up to a maximum of 129,600 seconds (36 hours), with a default of 43,200 seconds (12 hours).
353
+
354
+ Returns
355
+ -------
356
+ Optional[str]
357
+ The updated authorization header with the new JWT token.
358
+
359
+ Raises
360
+ ------
361
+ RuntimeError
362
+ If the refresh request fails.
363
+ ValueError
364
+ If no valid JWT token is available to refresh, or if the response is invalid.
365
+ """
366
+ if not self.auth_token:
367
+ raise ValueError(
368
+ "No JWT token available to refresh. Please authenticate first using "
369
+ "token_login() or authenticate() methods."
370
+ )
371
+
372
+ config = Configuration()
373
+ config.host = env.LIGHTNING_CLOUD_URL
374
+ # Set the current auth token as the authorization header
375
+ config.api_key_prefix['Authorization'] = 'Bearer'
376
+ config.api_key['Authorization'] = self.auth_token
377
+
378
+ api_client = ApiClient(configuration=config)
379
+ auth_api = AuthServiceApi(api_client)
380
+
381
+ logger.debug(f"Attempting to refresh JWT token with duration {duration} seconds")
382
+
383
+ try:
384
+ body = V1RefreshRequest(duration=str(duration))
385
+ response = auth_api.auth_service_refresh(body)
386
+
387
+ except requests.RequestException as e:
388
+ logger.error(f"Token refresh request failed: {e}")
389
+ raise RuntimeError(
390
+ "Failed to connect to the refresh endpoint. "
391
+ "Please check your network connection and the server status."
392
+ ) from e
393
+
394
+ # Extract the new JWT token from the response
395
+ new_jwt_token = getattr(response, "token", None)
396
+
397
+ if not new_jwt_token:
398
+ logger.error(
399
+ f"No token received from refresh response: {response}"
400
+ )
401
+ raise ValueError(
402
+ "The refresh response did not contain a valid JWT token."
403
+ )
404
+
405
+ # Save the new JWT token
406
+ self.save(auth_token=new_jwt_token)
407
+ logger.info("Successfully refreshed JWT token.")
408
+
409
+ return self.auth_header
410
+
411
+ def create_api_client(self, override: Optional[AuthOverride] = None) -> 'ApiClient':
412
+ """Create an API client with optional authentication override.
413
+
414
+ This is a convenience method for creating API clients that use a specific
415
+ authentication method instead of the default automatic selection.
416
+
417
+ Parameters
418
+ ----------
419
+ override : AuthOverride, optional
420
+ Override the default authentication method for this API client.
421
+ See get_auth_header() for available options.
422
+
423
+ Returns
424
+ -------
425
+ ApiClient
426
+ Configured API client with the specified authentication method.
427
+ """
428
+ from lightning_sdk.lightning_cloud.openapi import ApiClient, Configuration
429
+
430
+ config = Configuration()
431
+ config.host = env.LIGHTNING_CLOUD_URL
432
+
433
+ # Get the auth header for the specified override
434
+ auth_header = self.get_auth_header(override)
435
+
436
+ # Create the API client
437
+ client = ApiClient(configuration=config)
438
+
439
+ # Set the Authorization header directly in default_headers
440
+ if auth_header:
441
+ client.set_default_header('Authorization', auth_header)
442
+
443
+ return client
444
+
195
445
 
196
446
  class AuthServer:
197
447