lightning-sdk 2025.10.14__py3-none-any.whl → 2025.10.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +6 -3
- lightning_sdk/api/base_studio_api.py +13 -9
- lightning_sdk/api/license_api.py +26 -59
- lightning_sdk/api/studio_api.py +7 -2
- lightning_sdk/base_studio.py +30 -17
- lightning_sdk/cli/base_studio/list.py +1 -3
- lightning_sdk/cli/entrypoint.py +8 -34
- lightning_sdk/cli/studio/connect.py +42 -92
- lightning_sdk/cli/studio/create.py +23 -1
- lightning_sdk/cli/studio/start.py +12 -2
- lightning_sdk/cli/utils/get_base_studio.py +24 -0
- lightning_sdk/cli/utils/handle_machine_and_gpus_args.py +71 -0
- lightning_sdk/cli/utils/logging.py +121 -0
- lightning_sdk/cli/utils/ssh_connection.py +1 -1
- lightning_sdk/constants.py +1 -0
- lightning_sdk/helpers.py +53 -34
- lightning_sdk/lightning_cloud/login.py +260 -10
- lightning_sdk/lightning_cloud/openapi/__init__.py +10 -3
- lightning_sdk/lightning_cloud/openapi/api/auth_service_api.py +97 -0
- lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +108 -108
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +10 -3
- lightning_sdk/lightning_cloud/openapi/models/create_machine_request_represents_the_request_to_create_a_machine.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/id_fork_body1.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/license_key_validate_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_create_license_request.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_delete_license_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_transfer_estimate_response.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_incident.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_incident_detail.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_incident_event.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_license.py +227 -0
- lightning_sdk/lightning_cloud/openapi/models/{v1_list_product_licenses_response.py → v1_list_license_response.py} +16 -16
- lightning_sdk/lightning_cloud/openapi/models/v1_machine.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_slack_notifier.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_token_login_request.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_token_login_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_token_owner_type.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +53 -79
- lightning_sdk/lightning_cloud/openapi/models/{v1_product_license_check_response.py → v1_validate_license_response.py} +21 -21
- lightning_sdk/lightning_cloud/rest_client.py +48 -45
- lightning_sdk/machine.py +2 -0
- lightning_sdk/studio.py +14 -2
- lightning_sdk/utils/license.py +13 -0
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/METADATA +1 -1
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/RECORD +51 -41
- lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +0 -435
- lightning_sdk/services/license.py +0 -363
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/LICENSE +0 -0
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/WHEEL +0 -0
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-2025.10.14.dist-info → lightning_sdk-2025.10.22.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.base_studio import BaseStudio
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_base_studio_id(studio_type: Optional[str]) -> Optional[str]:
    """Resolve the ID of the base (template) studio a new studio should be created from.

    The first base studio returned by ``BaseStudio().list()`` is the fallback. When
    ``studio_type`` is given, a base studio whose name matches it (case-insensitively,
    with spaces treated as hyphens) takes precedence over that fallback.

    Args:
        studio_type: Optional user-supplied base studio name, e.g. ``"Full Stack"`` or ``"full-stack"``.

    Returns:
        The ID of the matching base studio, otherwise the ID of the first listed base
        studio, or ``None`` when no base studios are available.
    """
    available = BaseStudio().list()
    if not available:
        return None

    def _normalize(name: str) -> str:
        # "Full Stack" and "full-stack" should compare equal.
        return name.lower().replace(" ", "-")

    if studio_type:
        wanted = _normalize(studio_type)
        for candidate in available:
            if _normalize(candidate.name) == wanted:
                return candidate.id

    # No explicit type, or nothing matched: fall back to the first template studio.
    return available[0].id
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from typing import Dict, Optional, Set
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
|
|
5
|
+
from lightning_sdk.machine import Machine
|
|
6
|
+
|
|
7
|
+
# Machine returned by handle_machine_and_gpus_args when neither --machine nor --gpus is given.
DEFAULT_MACHINE = "CPU"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _split_gpus_spec(gpus: str) -> tuple[str, int]:
|
|
11
|
+
machine_name, machine_val = gpus.split(":", 1)
|
|
12
|
+
machine_name = machine_name.strip()
|
|
13
|
+
machine_val = machine_val.strip()
|
|
14
|
+
|
|
15
|
+
if not machine_val.isdigit() or int(machine_val) <= 0:
|
|
16
|
+
raise ValueError(f"Invalid GPU count '{machine_val}'. Must be a positive integer.")
|
|
17
|
+
|
|
18
|
+
machine_num = int(machine_val)
|
|
19
|
+
return machine_name, machine_num
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _construct_available_gpus(machine_options: Dict[str, str]) -> Set[str]:
|
|
23
|
+
# returns available gpus:count
|
|
24
|
+
available_gpus = set()
|
|
25
|
+
for v in machine_options.values():
|
|
26
|
+
if "_X_" in v:
|
|
27
|
+
gpu_type_num = v.replace("_X_", ":")
|
|
28
|
+
available_gpus.add(gpu_type_num)
|
|
29
|
+
else:
|
|
30
|
+
available_gpus.add(v)
|
|
31
|
+
return available_gpus
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_machine_from_gpus(gpus: str) -> str:
    """Translate a ``--gpus`` spec into the name of a CLI-selectable machine.

    Accepts ``"<type>"`` or ``"<type>:<count>"`` (e.g. ``"L4"``, ``"L4:1"``, ``"L4:4"``).
    The lookup is case-insensitive against the ``name`` of every ``Machine`` instance
    defined on the ``Machine`` class whose ``_include_in_cli`` flag is truthy.

    Args:
        gpus: The GPU spec, already stripped by the caller.

    Returns:
        The matching machine *name* (a ``str``; the dict below maps lower-cased names
        to ``m.name``), which ``handle_machine_and_gpus_args`` returns as the machine.

    Raises:
        ValueError: If the count is invalid (from ``_split_gpus_spec``) or no machine
            matches. The message lists the available options in sorted order so it is
            stable between runs.
    """
    machine_name = gpus
    machine_num = 1
    if ":" in gpus:
        machine_name, machine_num = _split_gpus_spec(gpus)

    machine_options = {
        m.name.lower(): m.name for m in Machine.__dict__.values() if isinstance(m, Machine) and m._include_in_cli
    }

    if machine_num == 1:
        # e.g. gpus=L4 or gpus=L4:1
        gpu_key = machine_name.lower()
        error_prefix = f"Invalid GPU type '{machine_name}'"
    else:
        # e.g. gpus=L4:4 -> multi-GPU machines are named "<TYPE>_X_<N>"
        gpu_key = f"{machine_name.lower()}_x_{machine_num}"
        error_prefix = f"Invalid GPU configuration '{gpus}'"

    try:
        return machine_options[gpu_key]
    except KeyError:
        # _construct_available_gpus returns a set; sort it for deterministic output.
        available = ", ".join(sorted(_construct_available_gpus(machine_options)))
        raise ValueError(f"{error_prefix}. Available options: {available}") from None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def handle_machine_and_gpus_args(machine: Optional[str], gpus: Optional[str]) -> str:
    """Resolve the machine to use from the mutually exclusive ``--machine``/``--gpus`` options.

    Args:
        machine: Value of ``--machine``, if given.
        gpus: Value of ``--gpus`` (e.g. ``"L4"`` or ``"L4:4"``), if given.

    Returns:
        ``machine`` unchanged when only it is set, the machine name derived from
        ``gpus`` when only that is set, and ``DEFAULT_MACHINE`` when neither is set.

    Raises:
        click.UsageError: If both options are provided.
        ValueError: Propagated from ``_get_machine_from_gpus`` for an unknown GPU spec.
    """
    if machine and gpus:
        raise click.UsageError("Options --machine and --gpus are mutually exclusive. Provide only one.")

    if gpus:
        return _get_machine_from_gpus(gpus.strip())

    return machine if machine else DEFAULT_MACHINE
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import shlex
|
|
2
|
+
import sys
|
|
3
|
+
import traceback
|
|
4
|
+
from contextlib import suppress
|
|
5
|
+
from time import time
|
|
6
|
+
from types import TracebackType
|
|
7
|
+
from typing import Any, Optional, Type
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
from rich.console import Group
|
|
11
|
+
from rich.panel import Panel
|
|
12
|
+
from rich.syntax import Syntax
|
|
13
|
+
from rich.text import Text
|
|
14
|
+
|
|
15
|
+
from lightning_sdk.cli.utils import rich_to_str
|
|
16
|
+
from lightning_sdk.constants import _LIGHTNING_DEBUG
|
|
17
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_create_sdk_command_history_request import (
|
|
18
|
+
V1CreateSDKCommandHistoryRequest,
|
|
19
|
+
)
|
|
20
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_sdk_command_history_severity import V1SDKCommandHistorySeverity
|
|
21
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_sdk_command_history_type import V1SDKCommandHistoryType
|
|
22
|
+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _log_command(message: str = "", duration: int = 0, error: Optional[str] = None) -> None:
    """Best-effort report of a CLI invocation to the SDK command-history endpoint.

    Args:
        message: Free-form description of the invocation. Together with any error
            suffix, it is truncated to 1000 characters.
        duration: Duration of the command in seconds.
        error: ``None`` for a successful run, otherwise the error text. ``"0"`` is
            reported with WARNING severity and any other value with ERROR severity.
            An empty string still counts as an error.

    This function never raises. Callers invoke it from ``finally`` blocks and from the
    exception hook, so a telemetry failure must not mask the command's own outcome.
    """
    with suppress(Exception):
        original_command = " ".join(shlex.quote(arg) for arg in sys.argv)
        # Building the client can itself fail (e.g. configuration/credential problems),
        # so it must sit inside the guard, not just the request.
        client = LightningClient(retry=False, max_tries=0)

        body = V1CreateSDKCommandHistoryRequest(
            command=original_command,
            duration=duration,
            message=message,
            project_id=None,
            severity=V1SDKCommandHistorySeverity.INFO,
            type=V1SDKCommandHistoryType.CLI,
        )

        # Compare against None: an exception whose str() is "" is still an error.
        if error is not None:
            body.severity = V1SDKCommandHistorySeverity.WARNING if error == "0" else V1SDKCommandHistorySeverity.ERROR
            body.message = body.message + f" | Error: {error}"

        # limit characters
        body.message = body.message[:1000]

        client.s_dk_command_history_service_create_sdk_command_history(body=body)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None:
|
|
50
|
+
"""CLI won't show tracebacks, just print the exception message."""
|
|
51
|
+
message = str(value.args[0]) if value.args else str(value) or "An unknown error occurred"
|
|
52
|
+
|
|
53
|
+
error_text = Text()
|
|
54
|
+
error_text.append(f"{exception_type.__name__}: ", style="bold red")
|
|
55
|
+
error_text.append(message, style="white")
|
|
56
|
+
|
|
57
|
+
renderables = [error_text]
|
|
58
|
+
|
|
59
|
+
if _LIGHTNING_DEBUG:
|
|
60
|
+
tb_text = "".join(traceback.format_exception(exception_type, value, tb))
|
|
61
|
+
renderables.append(Text("\n\nFull traceback:\n", style="bold yellow"))
|
|
62
|
+
renderables.append(Syntax(tb_text, "python", theme="monokai light", line_numbers=False, word_wrap=True))
|
|
63
|
+
else:
|
|
64
|
+
renderables.append(Text("\n\n🐞 To view the full traceback, set: LIGHTNING_DEBUG=1"))
|
|
65
|
+
|
|
66
|
+
renderables.append(Text("\n📘 Need help? Run: lightning <command> --help", style="cyan"))
|
|
67
|
+
|
|
68
|
+
text = rich_to_str(Panel(Group(*renderables), title="⚡ Lightning CLI Error", border_style="red"))
|
|
69
|
+
click.echo(text, color=True)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def logging_excepthook(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None:
    """Report an uncaught exception to command history, then show it to the user.

    The signature matches ``sys.excepthook``. Reporting is best-effort: any failure
    while formatting or sending the report is suppressed, so the user's error is not
    replaced by a secondary "Error in sys.excepthook" dump. The ``finally`` guarantees
    the user-facing message is printed even if reporting is interrupted.
    """
    try:
        with suppress(Exception):
            tb_str = "".join(traceback.format_exception(exception_type, value, tb))
            ctx = click.get_current_context(silent=True)
            command_context = ctx.command_path if ctx else "outside_command_context"

            message = (
                f"Command: {command_context} | Type: {exception_type.__name__!s} | Value: {value!s} | Traceback: {tb_str}"
            )
            # NOTE(review): no `error` is passed, so uncaught exceptions are recorded with
            # INFO severity by _log_command — confirm this is intended.
            _log_command(message=message)
    finally:
        _notify_exception(exception_type, value, tb)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class CommandLoggingGroup(click.Group):
    """``click.Group`` that reports each invocation (params, duration, error) via ``_log_command``."""

    def _format_ctx(self, ctx: click.Context) -> str:
        """Render the invocation as a one-line summary for the command-history log.

        Params that are ``True`` render as ``--name``, params that are ``False`` or
        ``None`` are omitted, and any other value renders as ``--name value``.
        """
        parts = []
        for k, v in ctx.params.items():
            if v is True:
                parts.append(f"--{k}")
            elif v is False or v is None:
                continue
            else:
                parts.append(f"--{k} {v}")
        params = " ".join(parts)
        args = " ".join(ctx.args)
        return (
            f"""Commands: {ctx.command_path} | Subcommand: {ctx.invoked_subcommand} | Params: {params} | Args:{args}"""
        )

    def invoke(self, ctx: click.Context) -> Any:
        """Overrides the default invoke to wrap command execution with tracking.

        ``click.ClickException`` is shown and converted into ``ctx.exit(exit_code)``.
        Any other exception is recorded and re-raised unchanged.
        """
        start_time = time()
        error_message = None

        try:
            return super().invoke(ctx)
        except click.ClickException as e:
            error_message = str(e)
            e.show()
            ctx.exit(e.exit_code)
        except Exception as e:
            error_message = str(e)
            raise
        finally:
            # Telemetry is best-effort: an exception here must not replace the
            # command's return value or the exception currently propagating.
            with suppress(Exception):
                _log_command(
                    message=self._format_ctx(ctx),
                    duration=int(time() - start_time),
                    error=error_message,
                )
|
lightning_sdk/constants.py
CHANGED
lightning_sdk/helpers.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import functools
|
|
2
1
|
import importlib
|
|
3
2
|
import os
|
|
4
3
|
import sys
|
|
@@ -10,48 +9,68 @@ import tqdm
|
|
|
10
9
|
import tqdm.std
|
|
11
10
|
from packaging import version as packaging_version
|
|
12
11
|
|
|
13
|
-
|
|
12
|
+
from lightning_sdk.constants import _LIGHTNING_DISABLE_VERSION_CHECK
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
"""Check PyPI for newer versions of ``lightning-sdk``.
|
|
19
|
-
|
|
20
|
-
Returning the newest version if different from the current or ``None`` otherwise.
|
|
15
|
+
class VersionChecker:
    """Handles version checking and upgrade prompts for lightning-sdk.

    PyPI is queried at most once per instance. The result is cached, including a
    "no newer version" result, and the upgrade warning is shown at most once. A
    re-entrant lock serialises the check, so concurrent callers share one query and
    see one warning.
    """

    def __init__(self, package_name: str = "lightning-sdk", timeout: float = 5.0) -> None:
        """Create a checker.

        Args:
            package_name: PyPI project name to check.
            timeout: Seconds to wait for PyPI before giving up; on timeout the check
                is treated as "no newer version".
        """
        import threading  # local import keeps this class self-contained

        self.package_name = package_name
        self.timeout = timeout
        self._warning_shown = False
        self._cached_version: Optional[str] = None
        # Distinguishes "not checked yet" from a cached "no newer version" (None) result;
        # without it a None result was never cached and PyPI was re-queried on every call.
        self._checked = False
        # Re-entrant because check_and_prompt_upgrade holds it while calling _get_newer_version.
        self._lock = threading.RLock()

    def _get_newer_version(self, curr_version: str) -> Optional[str]:
        """Check PyPI for newer versions of ``lightning-sdk``.

        Returning the newest version if different from the current or ``None`` otherwise.
        Only the first call queries PyPI; later calls return the cached answer.
        """
        with self._lock:
            if not self._checked and self._cached_version is None:
                self._cached_version = self._query_pypi(curr_version)
            self._checked = True
            return self._cached_version

    def _query_pypi(self, curr_version: str) -> Optional[str]:
        """Return PyPI's latest release if it is a newer, stable, non-yanked version, else ``None``.

        Best-effort: network failures, timeouts, malformed JSON, unexpected payload
        shapes and unparsable versions all yield ``None``.
        """
        try:
            if _LIGHTNING_DISABLE_VERSION_CHECK == 1 or packaging_version.parse(curr_version).is_prerelease:
                return None

            response = requests.get(f"https://pypi.org/pypi/{self.package_name}/json", timeout=self.timeout)
            response_json = response.json()
            if curr_version not in response_json["releases"]:
                # Always return None if not installed from PyPI (e.g. dev versions)
                return None

            latest_version = response_json["info"]["version"]
            parsed_version = packaging_version.parse(latest_version)
            is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
            return None if curr_version == latest_version or is_invalid else latest_version
        except (requests.exceptions.RequestException, ValueError, KeyError, TypeError):
            # ValueError also covers JSON decode errors and packaging's InvalidVersion.
            return None

    def check_and_prompt_upgrade(self, curr_version: str) -> None:
        """Checks that the current version of ``lightning-sdk`` is the latest on PyPI.

        If not, warn the user to upgrade ``lightning-sdk``.
        Tracks if the warning has already been shown in this session to avoid duplicate warnings.
        """
        with self._lock:
            if self._warning_shown:
                return

            new_version = self._get_newer_version(curr_version)
            if new_version:
                warnings.warn(
                    f"A newer version of {self.package_name} is available ({new_version}). "
                    f"Please consider upgrading with `pip install -U {self.package_name}`. "
                    "Not all platform functionality can be guaranteed to work with the current version.",
                    UserWarning,
                )
                self._warning_shown = True
|
|
55
74
|
|
|
56
75
|
|
|
57
76
|
def _set_tqdm_envvars_noninteractive() -> None:
|
|
@@ -5,7 +5,7 @@ import logging
|
|
|
5
5
|
import os
|
|
6
6
|
import pathlib
|
|
7
7
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Optional
|
|
8
|
+
from typing import Optional, Literal
|
|
9
9
|
from urllib.parse import urlencode
|
|
10
10
|
|
|
11
11
|
import webbrowser
|
|
@@ -22,14 +22,21 @@ from lightning_sdk.lightning_cloud.openapi.api import \
|
|
|
22
22
|
AuthServiceApi
|
|
23
23
|
from lightning_sdk.lightning_cloud.openapi.models.v1_guest_login_request import \
|
|
24
24
|
V1GuestLoginRequest
|
|
25
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_token_login_request import \
|
|
26
|
+
V1TokenLoginRequest
|
|
27
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_refresh_request import \
|
|
28
|
+
V1RefreshRequest
|
|
25
29
|
|
|
26
30
|
logger = logging.getLogger(__name__)
|
|
27
31
|
|
|
32
|
+
# Authentication override types
|
|
33
|
+
# Accepted values for the ``override`` argument of get_auth_header() and create_api_client();
# passing None to those methods means automatic selection.
AuthOverride = Literal["auth_token", "api_key", "guest"]
|
|
28
34
|
|
|
29
35
|
class Keys(Enum):
|
|
30
36
|
# USERNAME = "LIGHTNING_USERNAME"
|
|
31
37
|
USER_ID = "LIGHTNING_USER_ID"
|
|
32
38
|
API_KEY = "LIGHTNING_API_KEY"
|
|
39
|
+
AUTH_TOKEN = "LIGHTNING_AUTH_TOKEN"
|
|
33
40
|
|
|
34
41
|
@property
|
|
35
42
|
def suffix(self):
|
|
@@ -41,15 +48,18 @@ class Auth:
|
|
|
41
48
|
# username: Optional[str] = None
|
|
42
49
|
user_id: Optional[str] = None
|
|
43
50
|
api_key: Optional[str] = None
|
|
51
|
+
auth_token: Optional[str] = None
|
|
44
52
|
|
|
45
53
|
secrets_file = pathlib.Path(env.LIGHTNING_CREDENTIAL_PATH)
|
|
46
54
|
|
|
47
55
|
def __post_init__(self):
    """Fill unset credential fields from the environment.

    A value passed explicitly to the constructor wins. Only fields that are still
    ``None`` are read from the environment variable named by each ``Keys`` member.
    """
    # Generator is consumed lazily, so each field is checked right before it is filled.
    unset_keys = (key for key in Keys if getattr(self, key.suffix) is None)
    for key in unset_keys:
        setattr(self, key.suffix, os.environ.get(key.value, None))

    # Consulted by the authenticate method: True when an API key or JWT auth token is present.
    self._with_env_var = bool(self.api_key or self.auth_token)
|
|
53
63
|
|
|
54
64
|
def load(self) -> bool:
|
|
55
65
|
"""Load credentials from disk and update properties with credentials.
|
|
@@ -71,6 +81,7 @@ class Auth:
|
|
|
71
81
|
token: str = "",
|
|
72
82
|
user_id: str = "",
|
|
73
83
|
api_key: str = "",
|
|
84
|
+
auth_token: str = "",
|
|
74
85
|
# username: str = "",
|
|
75
86
|
) -> None:
|
|
76
87
|
"""save credentials to disk."""
|
|
@@ -81,6 +92,7 @@ class Auth:
|
|
|
81
92
|
# f"{Keys.USERNAME.suffix}": username,
|
|
82
93
|
f"{Keys.USER_ID.suffix}": user_id,
|
|
83
94
|
f"{Keys.API_KEY.suffix}": api_key,
|
|
95
|
+
f"{Keys.AUTH_TOKEN.suffix}": auth_token,
|
|
84
96
|
},
|
|
85
97
|
f,
|
|
86
98
|
)
|
|
@@ -88,6 +100,7 @@ class Auth:
|
|
|
88
100
|
# self.username = username
|
|
89
101
|
self.user_id = user_id
|
|
90
102
|
self.api_key = api_key
|
|
103
|
+
self.auth_token = auth_token
|
|
91
104
|
logger.debug("credentials saved successfully")
|
|
92
105
|
|
|
93
106
|
@classmethod
|
|
@@ -99,15 +112,78 @@ class Auth:
|
|
|
99
112
|
os.environ.pop(key.value, None)
|
|
100
113
|
logger.debug("credentials removed successfully")
|
|
101
114
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
115
|
+
def get_auth_header(self, override: Optional[AuthOverride] = None) -> Optional[str]:
    """Build the ``Authorization`` header value, optionally forcing a specific method.

    Without an override the JWT auth token is preferred (``Bearer``). Otherwise the
    API key is used (``Basic`` over base64 of ``user_id:api_key``).

    Parameters
    ----------
    override : AuthOverride, optional
        ``"auth_token"`` forces the JWT token (Bearer). ``"api_key"`` and ``"guest"``
        force Basic credentials built from ``user_id`` and ``api_key``. ``None``
        (the default) uses the automatic selection described above.

    Returns
    -------
    Optional[str]
        The header value; every failure path raises instead of returning ``None``.

    Raises
    ------
    ValueError
        If the requested credentials are missing, no credentials exist at all,
        or ``override`` is not one of the recognised values.
    """
    def basic_header() -> str:
        raw = f"{self.user_id}:{self.api_key}".encode('ascii')
        return "Basic " + base64.b64encode(raw).decode('ascii')

    if override is None:
        if self.auth_token:
            return f"Bearer {self.auth_token}"
        if self.api_key:
            return basic_header()
        raise ValueError(
            "No authentication credentials available. Please authenticate first using "
            "token_login(), guest_login(), or authenticate() methods."
        )

    if override == "auth_token":
        if self.auth_token:
            return f"Bearer {self.auth_token}"
        raise ValueError(
            "Auth token override requested but no JWT token available. "
            "Please use token_login() method first."
        )

    if override == "api_key":
        if self.api_key and self.user_id:
            return basic_header()
        raise ValueError(
            "API key override requested but no API key or user ID available. "
            "Please set LIGHTNING_API_KEY and LIGHTNING_USER_ID environment variables "
            "or use authenticate() method."
        )

    if override == "guest":
        if self.api_key and self.user_id:
            return basic_header()
        raise ValueError(
            "Guest override requested but no guest credentials available. "
            "Please call guest_login() method first."
        )

    raise ValueError(f"Invalid authentication override: {override}")

@property
def auth_header(self) -> Optional[str]:
    """Authorization header for the lightning-cloud client, chosen automatically (see get_auth_header)."""
    return self.get_auth_header(override=None)
|
|
111
187
|
|
|
112
188
|
def _run_server(self) -> None:
|
|
113
189
|
"""start a server to complete authentication."""
|
|
@@ -130,6 +206,8 @@ class Auth:
|
|
|
130
206
|
self._run_server()
|
|
131
207
|
return self.auth_header
|
|
132
208
|
|
|
209
|
+
elif self.auth_token:
|
|
210
|
+
return self.auth_header
|
|
133
211
|
elif self.user_id and self.api_key:
|
|
134
212
|
return self.auth_header
|
|
135
213
|
|
|
@@ -192,6 +270,178 @@ class Auth:
|
|
|
192
270
|
|
|
193
271
|
return self.auth_header
|
|
194
272
|
|
|
273
|
+
def token_login(self, token_key: str, save_token: bool = True) -> Optional[str]:
    """Performs token-based authentication.

    This method sends a request to the token login endpoint to authenticate
    using an auth token key, optionally saves the returned JWT token, and
    returns the authorization header.

    Parameters
    ----------
    token_key : str
        The auth token key to use for authentication. Must be non-empty.
    save_token : bool, optional
        Whether to save the JWT token for future use. Defaults to True.
        If False, the token is only used for the current session.

    Returns
    -------
    Optional[str]
        The authorization header to use for subsequent requests.

    Raises
    ------
    RuntimeError
        If the token login request fails for any reason (transport or API error).
    ValueError
        If ``token_key`` is empty, or the response from the server is invalid.
    """
    if not token_key or not token_key.strip():
        raise ValueError("token_key must be a non-empty string.")

    config = Configuration()
    config.host = env.LIGHTNING_CLOUD_URL
    auth_api = AuthServiceApi(ApiClient(configuration=config))

    logger.debug("Attempting token login to %s", config.host)

    try:
        response = auth_api.auth_service_token_login(V1TokenLoginRequest(token_key=token_key))
    except requests.RequestException as e:
        logger.error("Token login request failed: %s", e)
        raise RuntimeError(
            "Failed to connect to the token login endpoint. "
            "Please check your network connection and the server status."
        ) from e
    except Exception as e:
        # NOTE(review): the generated OpenAPI client presumably does not raise `requests`
        # exceptions (generated clients typically use urllib3 / ApiException), so the
        # branch above alone would let failures escape the documented RuntimeError contract.
        logger.error("Token login request failed: %s", e)
        raise RuntimeError(f"Token login request failed: {e}") from e

    # Extract the JWT token from the response
    jwt_token = getattr(response, "token", None)
    if not jwt_token:
        logger.error("No token received from token login response: %s", response)
        raise ValueError("The token login response did not contain a valid JWT token.")

    # Set the JWT token in memory
    self.auth_token = jwt_token

    # Optionally save the JWT token to disk.
    # NOTE(review): save() defaults user_id/api_key to "" and assigns them to self, so this
    # also clears any in-memory/on-disk API-key credentials — confirm that is intended.
    if save_token:
        self.save(auth_token=jwt_token)
        logger.info("Successfully authenticated using auth token and saved to disk.")
    else:
        logger.info("Successfully authenticated using auth token (not saved to disk).")

    return self.auth_header
|
|
340
|
+
|
|
341
|
+
def refresh_token(self, duration: int = 43200) -> Optional[str]:
    """Refreshes the current JWT token.

    This method sends a request to the refresh endpoint to get a new JWT token
    with the specified duration, saves the new token, and returns the updated
    authorization header.

    Parameters
    ----------
    duration : int, optional
        Duration in seconds for the new token. Can range from 900 seconds (15 minutes)
        up to a maximum of 129,600 seconds (36 hours), with a default of 43,200 seconds (12 hours).

    Returns
    -------
    Optional[str]
        The updated authorization header with the new JWT token.

    Raises
    ------
    RuntimeError
        If the refresh request fails for any reason (transport or API error).
    ValueError
        If no valid JWT token is available to refresh, or if the response is invalid.
    """
    if not self.auth_token:
        raise ValueError(
            "No JWT token available to refresh. Please authenticate first using "
            "token_login() or authenticate() methods."
        )

    config = Configuration()
    config.host = env.LIGHTNING_CLOUD_URL
    # Set the current auth token as the authorization header
    config.api_key_prefix['Authorization'] = 'Bearer'
    config.api_key['Authorization'] = self.auth_token

    api_client = ApiClient(configuration=config)
    # Also send it as a default header, the same way create_api_client() does.
    # NOTE(review): in generated clients config.api_key is typically applied only to
    # operations whose auth settings reference it — confirm for auth_service_refresh.
    api_client.set_default_header('Authorization', f"Bearer {self.auth_token}")
    auth_api = AuthServiceApi(api_client)

    logger.debug("Attempting to refresh JWT token with duration %s seconds", duration)

    try:
        response = auth_api.auth_service_refresh(V1RefreshRequest(duration=str(duration)))
    except requests.RequestException as e:
        logger.error("Token refresh request failed: %s", e)
        raise RuntimeError(
            "Failed to connect to the refresh endpoint. "
            "Please check your network connection and the server status."
        ) from e
    except Exception as e:
        # Wrap non-`requests` failures from the generated client to honour the documented contract.
        logger.error("Token refresh request failed: %s", e)
        raise RuntimeError(f"Token refresh request failed: {e}") from e

    # Extract the new JWT token from the response
    new_jwt_token = getattr(response, "token", None)
    if not new_jwt_token:
        logger.error("No token received from refresh response: %s", response)
        raise ValueError("The refresh response did not contain a valid JWT token.")

    # Save the new JWT token (save() also assigns self.auth_token).
    self.save(auth_token=new_jwt_token)
    logger.info("Successfully refreshed JWT token.")

    return self.auth_header
|
|
410
|
+
|
|
411
|
+
def create_api_client(self, override: Optional[AuthOverride] = None) -> 'ApiClient':
    """Build an API client authenticated with a chosen (or automatically selected) method.

    Parameters
    ----------
    override : AuthOverride, optional
        Authentication method to force for this client; ``None`` keeps the automatic
        selection. See get_auth_header() for the accepted values.

    Returns
    -------
    ApiClient
        A client pointed at ``env.LIGHTNING_CLOUD_URL`` that sends the resolved
        ``Authorization`` header on every request.

    Raises
    ------
    ValueError
        Propagated from get_auth_header() when the requested credentials are unavailable.
    """
    from lightning_sdk.lightning_cloud.openapi import ApiClient, Configuration

    configuration = Configuration()
    configuration.host = env.LIGHTNING_CLOUD_URL
    header_value = self.get_auth_header(override)

    api_client = ApiClient(configuration=configuration)
    if header_value:
        # Stored as a default header so it is attached to every request this client makes.
        api_client.set_default_header('Authorization', header_value)
    return api_client
|
|
444
|
+
|
|
195
445
|
|
|
196
446
|
class AuthServer:
|
|
197
447
|
|