lightning-sdk 2025.7.17__py3-none-any.whl → 2025.7.30rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +3 -2
- lightning_sdk/api/cloud_account_api.py +204 -0
- lightning_sdk/api/deployment_api.py +11 -0
- lightning_sdk/api/job_api.py +82 -10
- lightning_sdk/api/llm_api.py +1 -1
- lightning_sdk/api/mmt_api.py +44 -5
- lightning_sdk/api/pipeline_api.py +4 -3
- lightning_sdk/api/studio_api.py +51 -8
- lightning_sdk/api/utils.py +6 -2
- lightning_sdk/cli/clusters_menu.py +3 -3
- lightning_sdk/cli/create.py +25 -11
- lightning_sdk/cli/deploy/_auth.py +19 -3
- lightning_sdk/cli/deploy/serve.py +21 -5
- lightning_sdk/cli/download.py +25 -1
- lightning_sdk/cli/entrypoint.py +4 -2
- lightning_sdk/cli/list.py +5 -1
- lightning_sdk/cli/run.py +3 -1
- lightning_sdk/cli/start.py +40 -8
- lightning_sdk/cli/switch.py +3 -1
- lightning_sdk/deployment/deployment.py +8 -0
- lightning_sdk/job/base.py +27 -3
- lightning_sdk/job/job.py +28 -4
- lightning_sdk/job/v1.py +10 -1
- lightning_sdk/job/v2.py +22 -2
- lightning_sdk/job/work.py +5 -1
- lightning_sdk/lightning_cloud/openapi/__init__.py +14 -1
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +428 -0
- lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +153 -48
- lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +295 -0
- lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +93 -0
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +14 -1
- lightning_sdk/lightning_cloud/openapi/models/agentmanagedendpoints_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/blogposts_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/conversations_id_body1.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/messages_id_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/metricsstream_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_schedules_body.py +81 -3
- lightning_sdk/lightning_cloud/openapi/models/schedules_id_body.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/user_id_upgradetrigger_body.py +201 -0
- lightning_sdk/lightning_cloud/openapi/models/user_user_id_body.py +201 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_billing_subscription.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_blog_post.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_settings.py +227 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_billing_upgrade_trigger_record_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_create_blog_post_request.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_checkout_session_request.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_subscription_checkout_session_request.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_function_call.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/{v1_get_clickhouse_assistant_session_daily_aggregated_response.py → v1_get_assistant_session_daily_aggregated_response.py} +22 -22
- lightning_sdk/lightning_cloud/openapi/models/v1_get_cluster_health_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +105 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_like_status.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_published_managed_endpoints_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_endpoint.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +95 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_message.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_quote_subscription_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_resource_visibility.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_response_choice.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_schedule.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_service_health.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1_status.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_tool_call.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_like_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_message_like_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +79 -313
- lightning_sdk/lightning_cloud/openapi/models/v1_volume_state.py +1 -0
- lightning_sdk/llm/llm.py +69 -11
- lightning_sdk/llm/public_assistants.json +32 -8
- lightning_sdk/machine.py +151 -43
- lightning_sdk/mmt/base.py +20 -2
- lightning_sdk/mmt/mmt.py +25 -3
- lightning_sdk/mmt/v1.py +7 -1
- lightning_sdk/mmt/v2.py +27 -3
- lightning_sdk/models.py +1 -1
- lightning_sdk/organization.py +4 -0
- lightning_sdk/pipeline/pipeline.py +16 -5
- lightning_sdk/pipeline/printer.py +5 -3
- lightning_sdk/pipeline/schedule.py +844 -1
- lightning_sdk/pipeline/steps.py +19 -4
- lightning_sdk/sandbox.py +4 -1
- lightning_sdk/serve.py +2 -0
- lightning_sdk/studio.py +91 -44
- lightning_sdk/teamspace.py +19 -10
- lightning_sdk/utils/resolve.py +37 -2
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/METADATA +7 -5
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/RECORD +98 -85
- lightning_sdk/api/cluster_api.py +0 -119
- /lightning_sdk/cli/{inspect.py → inspection.py} +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/LICENSE +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/WHEEL +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/top_level.txt +0 -0
lightning_sdk/llm/llm.py
CHANGED
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
import os
|
|
3
3
|
from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
|
+
from lightning_sdk.api import TeamspaceApi, UserApi
|
|
5
6
|
from lightning_sdk.api.llm_api import LLMApi
|
|
6
7
|
from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
|
|
7
8
|
|
|
@@ -13,7 +14,7 @@ PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
|
|
|
13
14
|
}
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
def _load_public_assistants() -> Dict[str, str]:
|
|
17
|
+
def _load_public_assistants() -> Dict[str, Dict[str, Any]]:
|
|
17
18
|
"""Load public assistants from a JSON file."""
|
|
18
19
|
try:
|
|
19
20
|
json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
|
|
@@ -28,7 +29,7 @@ class LLM:
|
|
|
28
29
|
_auth_info_cached: ClassVar[bool] = False
|
|
29
30
|
_cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
|
|
30
31
|
_llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
|
|
31
|
-
_public_assistants: ClassVar[Optional[Dict[str, str]]] = None
|
|
32
|
+
_public_assistants: ClassVar[Optional[Dict[str, Dict[str, Any]]]] = None
|
|
32
33
|
|
|
33
34
|
def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
|
|
34
35
|
return super().__new__(cls)
|
|
@@ -55,8 +56,18 @@ class LLM:
|
|
|
55
56
|
Raises:
|
|
56
57
|
ValueError: If teamspace information cannot be resolved.
|
|
57
58
|
"""
|
|
58
|
-
|
|
59
|
-
|
|
59
|
+
teamspace_name = None
|
|
60
|
+
if teamspace:
|
|
61
|
+
try:
|
|
62
|
+
owner, teamspace_name = teamspace.split("/", maxsplit=1)
|
|
63
|
+
except ValueError as e:
|
|
64
|
+
raise ValueError(
|
|
65
|
+
f"Invalid teamspace format: '{teamspace}'. "
|
|
66
|
+
"Teamspace should be specified as '{teamspace_owner}/{teamspace_name}' "
|
|
67
|
+
"(e.g., 'my-org/my-teamspace')."
|
|
68
|
+
) from e
|
|
69
|
+
|
|
70
|
+
self._get_auth_info(teamspace_name)
|
|
60
71
|
|
|
61
72
|
self._model_provider, self._model_name = self._parse_model_name(name)
|
|
62
73
|
self._enable_async = enable_async
|
|
@@ -66,6 +77,7 @@ class LLM:
|
|
|
66
77
|
LLM._llm_api_cache[teamspace] = LLMApi()
|
|
67
78
|
self._llm_api = LLM._llm_api_cache[teamspace]
|
|
68
79
|
|
|
80
|
+
self._context_length = None
|
|
69
81
|
self._model_id = self._get_model_id()
|
|
70
82
|
self._conversations = {}
|
|
71
83
|
|
|
@@ -77,14 +89,45 @@ class LLM:
|
|
|
77
89
|
def provider(self) -> str:
|
|
78
90
|
return self._model_provider
|
|
79
91
|
|
|
80
|
-
def
|
|
92
|
+
def context_length(self, model: Optional[str] = None) -> Optional[int]:
|
|
93
|
+
if model is None:
|
|
94
|
+
return self._context_length
|
|
95
|
+
|
|
96
|
+
context_info = self._public_assistants.get(model)
|
|
97
|
+
if context_info is None or "context_length" not in context_info:
|
|
98
|
+
raise ValueError(f"Cannot access context length of model '{model}'.")
|
|
99
|
+
|
|
100
|
+
return int(context_info["context_length"])
|
|
101
|
+
|
|
102
|
+
def _get_auth_info(self, teamspace_name: Optional[str] = None) -> None:
|
|
103
|
+
# TODO: Validate user input teamspace name
|
|
81
104
|
if not LLM._auth_info_cached:
|
|
82
|
-
teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
|
|
83
105
|
if teamspace_name is None:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
106
|
+
# studio users
|
|
107
|
+
teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
|
|
108
|
+
|
|
109
|
+
if teamspace_name is None:
|
|
110
|
+
# local users with no given teamspace
|
|
111
|
+
try:
|
|
112
|
+
teamspace_api = TeamspaceApi()
|
|
113
|
+
user_api = UserApi()
|
|
114
|
+
authed_user = user_api._client.auth_service_get_user()
|
|
115
|
+
default_teamspace = teamspace_api.list_teamspaces(owner_id=authed_user.id)[0]
|
|
116
|
+
teamspace_name = default_teamspace.name
|
|
117
|
+
teamspace_id = default_teamspace.id
|
|
118
|
+
os.environ["LIGHTNING_CLOUD_PROJECT_ID"] = teamspace_id
|
|
119
|
+
os.environ["LIGHTNING_TEAMSPACE"] = teamspace_name
|
|
120
|
+
except Exception as err:
|
|
121
|
+
# throw an appropriate error that guides users to login through the platform
|
|
122
|
+
raise ValueError(
|
|
123
|
+
"Teamspace information is missing. "
|
|
124
|
+
"If this is your first time using LitAI, please log in at https://lightning.ai/sign-up "
|
|
125
|
+
"and re-run your script, or set the environment variable LIGHTNING_TEAMSPACE=<your-teamspace>."
|
|
126
|
+
) from err
|
|
127
|
+
|
|
128
|
+
# TODO: when teamspace_name is given, we don't know the teamspace_id yet
|
|
129
|
+
# TODO: if LIGHTNING_CLOUD_PROJECT_ID does not exist, we have to get the id from the teamspace name
|
|
130
|
+
|
|
88
131
|
LLM._cached_auth_info = {
|
|
89
132
|
"teamspace_name": teamspace_name,
|
|
90
133
|
"teamspace_id": os.environ.get("LIGHTNING_CLOUD_PROJECT_ID", None),
|
|
@@ -125,7 +168,10 @@ class LLM:
|
|
|
125
168
|
and LLM._public_assistants
|
|
126
169
|
and f"{self._model_provider}/{self._model_name}" in LLM._public_assistants
|
|
127
170
|
):
|
|
128
|
-
|
|
171
|
+
self._context_length = int(
|
|
172
|
+
LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]["context_length"]
|
|
173
|
+
)
|
|
174
|
+
return LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]["id"]
|
|
129
175
|
try:
|
|
130
176
|
return self._llm_api.get_assistant(
|
|
131
177
|
model_provider=PUBLIC_MODEL_PROVIDERS[self._model_provider],
|
|
@@ -139,6 +185,18 @@ class LLM:
|
|
|
139
185
|
"Please check the model name or provider."
|
|
140
186
|
) from e
|
|
141
187
|
|
|
188
|
+
if self._model_provider == "lightning-ai":
|
|
189
|
+
# Try model provider model
|
|
190
|
+
try:
|
|
191
|
+
return self._llm_api.get_assistant(
|
|
192
|
+
model_provider=self._model_provider,
|
|
193
|
+
model_name=self._model_name,
|
|
194
|
+
user_name="",
|
|
195
|
+
org_name="",
|
|
196
|
+
)
|
|
197
|
+
except Exception:
|
|
198
|
+
pass
|
|
199
|
+
|
|
142
200
|
# Try organization model
|
|
143
201
|
try:
|
|
144
202
|
return self._llm_api.get_assistant(
|
|
@@ -1,10 +1,34 @@
|
|
|
1
1
|
{
|
|
2
|
-
"openai/gpt-4o":
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
"
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
2
|
+
"openai/gpt-4o": {
|
|
3
|
+
"id": "ast_01jdjds71fs8gt47jexzed4czs",
|
|
4
|
+
"context_length": 128000
|
|
5
|
+
},
|
|
6
|
+
"openai/gpt-4": {
|
|
7
|
+
"id": "ast_01jd38ze6tjbrcd4942nhz41zn",
|
|
8
|
+
"context_length": 8192
|
|
9
|
+
},
|
|
10
|
+
"openai/o3-mini": {
|
|
11
|
+
"id": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
|
|
12
|
+
"context_length": 128000
|
|
13
|
+
},
|
|
14
|
+
"anthropic/claude-3-5-sonnet-20240620": {
|
|
15
|
+
"id": "ast_01jd3923a6p98rqwh3dpj686pq",
|
|
16
|
+
"context_length": 200000
|
|
17
|
+
},
|
|
18
|
+
"google/gemini-2.5-pro": {
|
|
19
|
+
"id": "ast_01jz3tdb1fhey798k95pv61v57",
|
|
20
|
+
"context_length": 1048576
|
|
21
|
+
},
|
|
22
|
+
"google/gemini-2.5-flash": {
|
|
23
|
+
"id": "ast_01jz3thxskg4fcdk4xhkjkym5a",
|
|
24
|
+
"context_length": 8000
|
|
25
|
+
},
|
|
26
|
+
"google/gemini-2.5-flash-lite-preview-06-17": {
|
|
27
|
+
"id": "ast_01jz3thxskg4fcdk4xhkjkym5b",
|
|
28
|
+
"context_length": 8000
|
|
29
|
+
},
|
|
30
|
+
"lightning-ai/llama4-maverick": {
|
|
31
|
+
"id": "ast_01k0wgg56tm8mv9n12aq2mnxas",
|
|
32
|
+
"context_length": 100000
|
|
33
|
+
}
|
|
10
34
|
}
|
lightning_sdk/machine.py
CHANGED
|
@@ -1,60 +1,113 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
2
3
|
from typing import Any, ClassVar, Optional, Tuple
|
|
3
4
|
|
|
4
5
|
|
|
6
|
+
class CloudProvider(Enum):
|
|
7
|
+
AWS = "AWS"
|
|
8
|
+
GCP = "GCP"
|
|
9
|
+
VULTR = "VULTR"
|
|
10
|
+
LAMBDA_LABS = "LAMBDA_LABS"
|
|
11
|
+
DGX = "DGX"
|
|
12
|
+
VOLTAGE_PARK = "VOLTAGE_PARK"
|
|
13
|
+
NEBIUS = "NEBIUS"
|
|
14
|
+
LIGHTNING = "LIGHTNING"
|
|
15
|
+
|
|
16
|
+
def __str__(self) -> str:
|
|
17
|
+
"""Converts the CloudProvider to a str."""
|
|
18
|
+
return self.value
|
|
19
|
+
|
|
20
|
+
|
|
5
21
|
@dataclass(frozen=True)
|
|
6
22
|
class Machine:
|
|
7
|
-
#
|
|
8
|
-
|
|
23
|
+
# supported CPU variations
|
|
24
|
+
CPU_X_2: ClassVar["Machine"]
|
|
25
|
+
CPU_X_4: ClassVar["Machine"]
|
|
26
|
+
CPU_X_8: ClassVar["Machine"]
|
|
27
|
+
CPU_X_16: ClassVar["Machine"]
|
|
28
|
+
# default CPU machines
|
|
9
29
|
CPU_SMALL: ClassVar["Machine"]
|
|
30
|
+
CPU: ClassVar["Machine"]
|
|
31
|
+
# supported data-prep variations (big disk)
|
|
10
32
|
DATA_PREP: ClassVar["Machine"]
|
|
11
33
|
DATA_PREP_MAX: ClassVar["Machine"]
|
|
12
34
|
DATA_PREP_ULTRA: ClassVar["Machine"]
|
|
35
|
+
|
|
36
|
+
# supported GPU types
|
|
37
|
+
# supported T4 variations
|
|
13
38
|
T4: ClassVar["Machine"]
|
|
39
|
+
T4_X_2: ClassVar["Machine"]
|
|
14
40
|
T4_X_4: ClassVar["Machine"]
|
|
41
|
+
T4_X_8: ClassVar["Machine"]
|
|
42
|
+
# supported L4 variations
|
|
15
43
|
L4: ClassVar["Machine"]
|
|
16
44
|
L4_X_2: ClassVar["Machine"]
|
|
17
45
|
L4_X_4: ClassVar["Machine"]
|
|
18
46
|
L4_X_8: ClassVar["Machine"]
|
|
19
|
-
|
|
20
|
-
A10G_X_4: ClassVar["Machine"]
|
|
21
|
-
A10G_X_8: ClassVar["Machine"]
|
|
47
|
+
# supported L40S variations
|
|
22
48
|
L40S: ClassVar["Machine"]
|
|
49
|
+
L40S_X_2: ClassVar["Machine"]
|
|
23
50
|
L40S_X_4: ClassVar["Machine"]
|
|
24
51
|
L40S_X_8: ClassVar["Machine"]
|
|
52
|
+
# supported A100 variations
|
|
53
|
+
# defaults, can be either A100 type depending on cloud provider availability
|
|
54
|
+
A100: ClassVar["Machine"]
|
|
25
55
|
A100_X_2: ClassVar["Machine"]
|
|
26
56
|
A100_X_4: ClassVar["Machine"]
|
|
27
57
|
A100_X_8: ClassVar["Machine"]
|
|
28
|
-
|
|
58
|
+
# A100 40GB versions
|
|
59
|
+
A100_40GB: ClassVar["Machine"]
|
|
60
|
+
A100_40GB_X_2: ClassVar["Machine"]
|
|
61
|
+
A100_40GB_X_4: ClassVar["Machine"]
|
|
62
|
+
A100_40GB_X_8: ClassVar["Machine"]
|
|
63
|
+
# A100 80GB versions
|
|
64
|
+
A100_80GB: ClassVar["Machine"]
|
|
65
|
+
A100_80GB_X_2: ClassVar["Machine"]
|
|
66
|
+
A100_80GB_X_4: ClassVar["Machine"]
|
|
67
|
+
A100_80GB_X_8: ClassVar["Machine"]
|
|
68
|
+
|
|
69
|
+
H100: ClassVar["Machine"]
|
|
70
|
+
H100_X_2: ClassVar["Machine"]
|
|
71
|
+
H100_X_4: ClassVar["Machine"]
|
|
29
72
|
H100_X_8: ClassVar["Machine"]
|
|
73
|
+
|
|
74
|
+
H200: ClassVar["Machine"]
|
|
30
75
|
H200_X_8: ClassVar["Machine"]
|
|
76
|
+
B200_X_8: ClassVar["Machine"]
|
|
77
|
+
|
|
78
|
+
# Specialized Machines
|
|
31
79
|
|
|
32
80
|
name: str
|
|
33
|
-
|
|
81
|
+
slug: str
|
|
82
|
+
instance_type: Optional[str] = None
|
|
83
|
+
family: Optional[str] = None
|
|
84
|
+
accelerator_count: Optional[int] = None
|
|
34
85
|
cost: Optional[float] = None
|
|
35
86
|
interruptible_cost: Optional[float] = None
|
|
36
87
|
wait_time: Optional[float] = None
|
|
37
88
|
interruptible_wait_time: Optional[float] = None
|
|
89
|
+
_include_in_cli: bool = True
|
|
38
90
|
|
|
39
91
|
def __str__(self) -> str:
|
|
40
92
|
"""String representation of the Machine."""
|
|
41
|
-
return str(self.name) if self.name else str(self.instance_type)
|
|
93
|
+
return str(self.name) if self.name else (self.slug if self.slug else str(self.instance_type))
|
|
42
94
|
|
|
43
95
|
def __eq__(self, other: object) -> bool:
|
|
44
96
|
"""Machines are equal if the instance type is equal."""
|
|
45
97
|
if isinstance(other, Machine):
|
|
46
|
-
return
|
|
98
|
+
return (
|
|
99
|
+
# equality based on raw instance type (provider specific)
|
|
100
|
+
(self.instance_type and self.instance_type == other.instance_type)
|
|
101
|
+
# equality based on slug (provider agnostic)
|
|
102
|
+
or self.slug == other.slug
|
|
103
|
+
# equality based on machine specs (e.g. A100_80GB_X_8 == A100_X_8)
|
|
104
|
+
or (self.family == other.family and self.accelerator_count == other.accelerator_count)
|
|
105
|
+
)
|
|
47
106
|
return False
|
|
48
107
|
|
|
49
108
|
def is_cpu(self) -> bool:
|
|
50
109
|
"""Whether the machine is a CPU."""
|
|
51
|
-
return (
|
|
52
|
-
self == Machine.CPU
|
|
53
|
-
or self == Machine.CPU_SMALL
|
|
54
|
-
or self == Machine.DATA_PREP
|
|
55
|
-
or self == Machine.DATA_PREP_MAX
|
|
56
|
-
or self == Machine.DATA_PREP_ULTRA
|
|
57
|
-
)
|
|
110
|
+
return self.family in ("CPU", "DATA_PREP")
|
|
58
111
|
|
|
59
112
|
@classmethod
|
|
60
113
|
def from_str(cls, machine: str, *additional_machine_ids: Any) -> "Machine":
|
|
@@ -63,34 +116,89 @@ class Machine:
|
|
|
63
116
|
)
|
|
64
117
|
for m in possible_values:
|
|
65
118
|
for machine_id in [machine, *additional_machine_ids]:
|
|
66
|
-
if machine_id in (
|
|
119
|
+
if machine_id in (
|
|
120
|
+
getattr(m, "name", None),
|
|
121
|
+
getattr(m, "instance_type", None),
|
|
122
|
+
getattr(m, "slug", None),
|
|
123
|
+
):
|
|
67
124
|
return m
|
|
68
125
|
|
|
69
126
|
if additional_machine_ids:
|
|
70
127
|
return cls(machine, *additional_machine_ids)
|
|
71
|
-
return cls(machine, machine)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
Machine.
|
|
77
|
-
Machine.
|
|
78
|
-
|
|
79
|
-
Machine.
|
|
80
|
-
Machine.
|
|
81
|
-
Machine.
|
|
82
|
-
Machine.
|
|
83
|
-
|
|
84
|
-
Machine.
|
|
85
|
-
Machine.
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
Machine.
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
Machine.
|
|
95
|
-
Machine.
|
|
96
|
-
Machine.
|
|
128
|
+
return cls(machine, machine, machine)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# CPU machines
|
|
132
|
+
# default CPU machines
|
|
133
|
+
Machine.CPU_SMALL = Machine(name="CPU_SMALL", slug="cpu-2", family="CPU", accelerator_count=2)
|
|
134
|
+
Machine.CPU = Machine(name="CPU", slug="cpu-4", family="CPU", accelerator_count=4)
|
|
135
|
+
# available CPU variations
|
|
136
|
+
Machine.CPU_X_2 = Machine(name="CPU_X_2", slug="cpu-2", family="CPU", accelerator_count=2)
|
|
137
|
+
Machine.CPU_X_4 = Machine(name="CPU_X_4", slug="cpu-4", family="CPU", accelerator_count=4)
|
|
138
|
+
Machine.CPU_X_8 = Machine(name="CPU_X_8", slug="cpu-8", family="CPU", accelerator_count=8)
|
|
139
|
+
Machine.CPU_X_16 = Machine(name="CPU_X_16", slug="cpu-16", family="CPU", accelerator_count=16)
|
|
140
|
+
# available data-prep (big disk) machines
|
|
141
|
+
Machine.DATA_PREP = Machine(name="DATA_PREP", slug="data-prep-mid", family="DATA_PREP", accelerator_count=32)
|
|
142
|
+
Machine.DATA_PREP_MAX = Machine(
|
|
143
|
+
name="DATA_PREP_MAX", slug="data-prep-max-large", family="DATA_PREP", accelerator_count=64
|
|
144
|
+
)
|
|
145
|
+
Machine.DATA_PREP_ULTRA = Machine(
|
|
146
|
+
name="DATA_PREP_ULTRA", slug="data-prep-ultra-extra-large", family="DATA_PREP", accelerator_count=96
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
# GPU machines
|
|
150
|
+
# available T4 machines
|
|
151
|
+
Machine.T4 = Machine(name="T4", slug="lit-t4-1", family="T4", accelerator_count=1)
|
|
152
|
+
Machine.T4_X_2 = Machine(name="T4_X_2", slug="lit-t4-2", family="T4", accelerator_count=2)
|
|
153
|
+
Machine.T4_X_4 = Machine(name="T4_X_4", slug="lit-t4-4", family="T4", accelerator_count=4)
|
|
154
|
+
Machine.T4_X_8 = Machine(name="T4_X_8", slug="lit-t4-8", family="T4", accelerator_count=8)
|
|
155
|
+
# available L4 machines
|
|
156
|
+
Machine.L4 = Machine(name="L4", slug="lit-l4-1", family="L4", accelerator_count=1)
|
|
157
|
+
Machine.L4_X_2 = Machine(name="L4_X_2", slug="lit-l4-2", family="L4", accelerator_count=2)
|
|
158
|
+
Machine.L4_X_4 = Machine(name="L4_X_4", slug="lit-l4-4", family="L4", accelerator_count=4)
|
|
159
|
+
Machine.L4_X_8 = Machine(name="L4_X_8", slug="lit-l4-8", family="L4", accelerator_count=8)
|
|
160
|
+
# available L40S machines
|
|
161
|
+
Machine.L40S = Machine(name="L40S", slug="lit-l40s-1", family="L40S", accelerator_count=1)
|
|
162
|
+
Machine.L40S_X_2 = Machine(name="L40S_X_2", slug="lit-l40s-2", family="L40S", accelerator_count=2)
|
|
163
|
+
Machine.L40S_X_4 = Machine(name="L40S_X_4", slug="lit-l40s-4", family="L40S", accelerator_count=4)
|
|
164
|
+
Machine.L40S_X_8 = Machine(name="L40S_X_8", slug="lit-l40s-8", family="L40S", accelerator_count=8)
|
|
165
|
+
# available A100 Machines
|
|
166
|
+
Machine.A100 = Machine(name="A100", slug="lit-a100-1", family="A100", accelerator_count=1)
|
|
167
|
+
Machine.A100_X_2 = Machine(name="A100_X_2", slug="lit-a100-2", family="A100", accelerator_count=2)
|
|
168
|
+
Machine.A100_X_4 = Machine(name="A100_X_4", slug="lit-a100-4", family="A100", accelerator_count=4)
|
|
169
|
+
Machine.A100_X_8 = Machine(name="A100_X_8", slug="lit-a100-8", family="A100", accelerator_count=8)
|
|
170
|
+
# don't include variants in cli, only default types that can match for all variants
|
|
171
|
+
Machine.A100_40GB = Machine(
|
|
172
|
+
name="A100_40GB", slug="lit-a100-40gb-1", family="A100", accelerator_count=1, _include_in_cli=False
|
|
173
|
+
)
|
|
174
|
+
Machine.A100_40GB_X_2 = Machine(
|
|
175
|
+
name="A100_40GB_X_2", slug="lit-a100-40gb-2", family="A100", accelerator_count=2, _include_in_cli=False
|
|
176
|
+
)
|
|
177
|
+
Machine.A100_40GB_X_4 = Machine(
|
|
178
|
+
name="A100_40GB_X_4", slug="lit-a100-40gb-4", family="A100", accelerator_count=4, _include_in_cli=False
|
|
179
|
+
)
|
|
180
|
+
Machine.A100_40GB_X_8 = Machine(
|
|
181
|
+
name="A100_40GB_X_8", slug="lit-a100-40gb-8", family="A100", accelerator_count=8, _include_in_cli=False
|
|
182
|
+
)
|
|
183
|
+
Machine.A100_80GB = Machine(
|
|
184
|
+
name="A100_80GB", slug="lit-a100-80gb-1", family="A100", accelerator_count=1, _include_in_cli=False
|
|
185
|
+
)
|
|
186
|
+
Machine.A100_80GB_X_2 = Machine(
|
|
187
|
+
name="A100_80GB_X_2", slug="lit-a100-80gb-2", family="A100", accelerator_count=2, _include_in_cli=False
|
|
188
|
+
)
|
|
189
|
+
Machine.A100_80GB_X_4 = Machine(
|
|
190
|
+
name="A100_80GB_X_4", slug="lit-a100-80gb-4", family="A100", accelerator_count=4, _include_in_cli=False
|
|
191
|
+
)
|
|
192
|
+
Machine.A100_80GB_X_8 = Machine(
|
|
193
|
+
name="A100_80GB_X_8", slug="lit-a100-80gb-8", family="A100", accelerator_count=8, _include_in_cli=False
|
|
194
|
+
)
|
|
195
|
+
# available H100 machines
|
|
196
|
+
Machine.H100 = Machine(name="H100", slug="lit-h100-1", family="H100", accelerator_count=1)
|
|
197
|
+
Machine.H100_X_2 = Machine(name="H100_X_2", slug="lit-h100-2", family="H100", accelerator_count=2)
|
|
198
|
+
Machine.H100_X_4 = Machine(name="H100_X_4", slug="lit-h100-4", family="H100", accelerator_count=4)
|
|
199
|
+
Machine.H100_X_8 = Machine(name="H100_X_8", slug="lit-h100-8", family="H100", accelerator_count=8)
|
|
200
|
+
# available H200 machines
|
|
201
|
+
Machine.H200 = Machine(name="H200", slug="lit-h200x-1", family="H200", accelerator_count=1)
|
|
202
|
+
Machine.H200_X_8 = Machine(name="H200_X_8", slug="lit-h200x-8", family="H200", accelerator_count=8)
|
|
203
|
+
# available B200 machines
|
|
204
|
+
Machine.B200_X_8 = Machine(name="B200_X_8", slug="lit-b200x-8", family="B200", accelerator_count=8)
|
lightning_sdk/mmt/base.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple, Union
|
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
6
|
from lightning_sdk.job.base import MachineDict
|
|
7
|
-
from lightning_sdk.machine import Machine
|
|
7
|
+
from lightning_sdk.machine import CloudProvider, Machine
|
|
8
8
|
from lightning_sdk.organization import Organization
|
|
9
9
|
from lightning_sdk.status import Status
|
|
10
10
|
from lightning_sdk.studio import Studio
|
|
@@ -64,12 +64,14 @@ class _BaseMMT(_BaseJob):
|
|
|
64
64
|
org: Union[str, "Organization", None] = None,
|
|
65
65
|
user: Union[str, "User", None] = None,
|
|
66
66
|
cloud_account: Optional[str] = None,
|
|
67
|
+
cloud_provider: Optional[Union["CloudProvider", str]] = None,
|
|
67
68
|
env: Optional[Dict[str, str]] = None,
|
|
68
69
|
interruptible: bool = False,
|
|
69
70
|
image_credentials: Optional[str] = None,
|
|
70
71
|
cloud_account_auth: bool = False,
|
|
71
72
|
entrypoint: str = "sh -c",
|
|
72
73
|
path_mappings: Optional[Dict[str, str]] = None,
|
|
74
|
+
max_runtime: Optional[int] = None,
|
|
73
75
|
artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
|
|
74
76
|
artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
|
|
75
77
|
cluster: Optional[str] = None, # deprecated in favor of cloud_account
|
|
@@ -89,7 +91,11 @@ class _BaseMMT(_BaseJob):
|
|
|
89
91
|
user: The user owning the teamspace (if any). Defaults to the current user.
|
|
90
92
|
cloud_account: The cloud account to run the job on.
|
|
91
93
|
Defaults to the studio cloud account if running with studio compute env.
|
|
92
|
-
If not provided
|
|
94
|
+
If not provided and `cloud_account_provider` is set, will resolve cluster from this, else
|
|
95
|
+
will fall back to the teamspaces default cloud account.
|
|
96
|
+
cloud_account_provider: The provider to select the cloud-account from.
|
|
97
|
+
If set, must be in agreement with the provider from the cloud_account (if specified).
|
|
98
|
+
If not specified, falls backto the teamspace default cloud account.
|
|
93
99
|
env: Environment variables to set inside the job.
|
|
94
100
|
interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
|
|
95
101
|
image_credentials: The credentials used to pull the image. Required if the image is private.
|
|
@@ -109,6 +115,10 @@ class _BaseMMT(_BaseJob):
|
|
|
109
115
|
}
|
|
110
116
|
If the path inside the connection is omitted it's assumed to be the root path of that connection.
|
|
111
117
|
Only applicable when submitting docker jobs.
|
|
118
|
+
max_runtime: the duration (in seconds) for which to allocate the machine.
|
|
119
|
+
Irrelevant for most machines, required for some of the top-end machines on GCP.
|
|
120
|
+
If in doubt, set it. Won't have an effect on machines not requiring it.
|
|
121
|
+
Defaults to 3h
|
|
112
122
|
"""
|
|
113
123
|
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
|
|
114
124
|
from lightning_sdk.studio import Studio
|
|
@@ -191,6 +201,7 @@ class _BaseMMT(_BaseJob):
|
|
|
191
201
|
num_machines=num_machines,
|
|
192
202
|
machine=machine,
|
|
193
203
|
cloud_account=cloud_account,
|
|
204
|
+
cloud_provider=cloud_provider,
|
|
194
205
|
command=command,
|
|
195
206
|
studio=studio,
|
|
196
207
|
image=image,
|
|
@@ -202,6 +213,7 @@ class _BaseMMT(_BaseJob):
|
|
|
202
213
|
path_mappings=path_mappings,
|
|
203
214
|
artifacts_local=artifacts_local,
|
|
204
215
|
artifacts_remote=artifacts_remote,
|
|
216
|
+
max_runtime=max_runtime,
|
|
205
217
|
)
|
|
206
218
|
return inst
|
|
207
219
|
|
|
@@ -216,12 +228,14 @@ class _BaseMMT(_BaseJob):
|
|
|
216
228
|
env: Optional[Dict[str, str]] = None,
|
|
217
229
|
interruptible: bool = False,
|
|
218
230
|
cloud_account: Optional[str] = None,
|
|
231
|
+
cloud_provider: Optional[Union["CloudProvider", str]] = None,
|
|
219
232
|
image_credentials: Optional[str] = None,
|
|
220
233
|
cloud_account_auth: bool = False,
|
|
221
234
|
entrypoint: str = "sh -c",
|
|
222
235
|
path_mappings: Optional[Dict[str, str]] = None,
|
|
223
236
|
artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
|
|
224
237
|
artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
|
|
238
|
+
max_runtime: Optional[int] = None,
|
|
225
239
|
) -> None:
|
|
226
240
|
"""Submit a new multi-machine job to the Lightning AI platform.
|
|
227
241
|
|
|
@@ -253,6 +267,10 @@ class _BaseMMT(_BaseJob):
|
|
|
253
267
|
}
|
|
254
268
|
If the path inside the connection is omitted it's assumed to be the root path of that connection.
|
|
255
269
|
Only applicable when submitting docker jobs.
|
|
270
|
+
max_runtime: the duration (in seconds) for which to allocate the machine.
|
|
271
|
+
Irrelevant for most machines, required for some of the top-end machines on GCP.
|
|
272
|
+
If in doubt, set it. Won't have an effect on machines not requiring it.
|
|
273
|
+
Defaults to 3h
|
|
256
274
|
"""
|
|
257
275
|
|
|
258
276
|
@property
|
lightning_sdk/mmt/mmt.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
2
2
|
|
|
3
|
+
from lightning_sdk.api.cloud_account_api import CloudAccountApi
|
|
3
4
|
from lightning_sdk.mmt.base import MMTMachine, _BaseMMT
|
|
4
5
|
from lightning_sdk.mmt.v1 import _MMTV1
|
|
5
6
|
from lightning_sdk.mmt.v2 import _MMTV2
|
|
6
7
|
from lightning_sdk.utils.resolve import _setup_logger
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
9
|
-
from lightning_sdk.machine import Machine
|
|
10
|
+
from lightning_sdk.machine import CloudProvider, Machine
|
|
10
11
|
from lightning_sdk.organization import Organization
|
|
11
12
|
from lightning_sdk.status import Status
|
|
12
13
|
from lightning_sdk.studio import Studio
|
|
@@ -75,6 +76,7 @@ class MMT(_BaseMMT):
|
|
|
75
76
|
)
|
|
76
77
|
|
|
77
78
|
self._internal_mmt = mmt
|
|
79
|
+
self._cloud_account_api = CloudAccountApi()
|
|
78
80
|
|
|
79
81
|
@classmethod
|
|
80
82
|
def run(
|
|
@@ -89,12 +91,14 @@ class MMT(_BaseMMT):
|
|
|
89
91
|
org: Union[str, "Organization", None] = None,
|
|
90
92
|
user: Union[str, "User", None] = None,
|
|
91
93
|
cloud_account: Optional[str] = None,
|
|
94
|
+
cloud_provider: Optional[Union["CloudProvider", str]] = None,
|
|
92
95
|
env: Optional[Dict[str, str]] = None,
|
|
93
96
|
interruptible: bool = False,
|
|
94
97
|
image_credentials: Optional[str] = None,
|
|
95
98
|
cloud_account_auth: bool = False,
|
|
96
99
|
entrypoint: str = "sh -c",
|
|
97
100
|
path_mappings: Optional[Dict[str, str]] = None,
|
|
101
|
+
max_runtime: Optional[int] = None,
|
|
98
102
|
artifacts_local: Optional[str] = None,
|
|
99
103
|
artifacts_remote: Optional[str] = None,
|
|
100
104
|
cluster: Optional[str] = None, # deprecated in favor of cloud_account
|
|
@@ -114,7 +118,11 @@ class MMT(_BaseMMT):
|
|
|
114
118
|
user: The user owning the teamspace (if any). Defaults to the current user.
|
|
115
119
|
cloud_account: The cloud account to run the job on.
|
|
116
120
|
Defaults to the studio cloud account if running with studio compute env.
|
|
117
|
-
If not provided
|
|
121
|
+
If not provided and `cloud_account_provider` is set, will resolve cluster from this, else
|
|
122
|
+
will fall back to the teamspaces default cloud account.
|
|
123
|
+
cloud_account_provider: The provider to select the cloud-account from.
|
|
124
|
+
If set, must be in agreement with the provider from the cloud_account (if specified).
|
|
125
|
+
If not specified, falls backto the teamspace default cloud account.
|
|
118
126
|
env: Environment variables to set inside the job.
|
|
119
127
|
interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
|
|
120
128
|
image_credentials: The credentials used to pull the image. Required if the image is private.
|
|
@@ -145,6 +153,7 @@ class MMT(_BaseMMT):
|
|
|
145
153
|
org=org,
|
|
146
154
|
user=user,
|
|
147
155
|
cloud_account=cloud_account,
|
|
156
|
+
cloud_provider=cloud_provider,
|
|
148
157
|
env=env,
|
|
149
158
|
interruptible=interruptible,
|
|
150
159
|
image_credentials=image_credentials,
|
|
@@ -154,6 +163,7 @@ class MMT(_BaseMMT):
|
|
|
154
163
|
artifacts_local=artifacts_local,
|
|
155
164
|
artifacts_remote=artifacts_remote,
|
|
156
165
|
cluster=cluster, # deprecated in favor of cloud_account
|
|
166
|
+
max_runtime=max_runtime,
|
|
157
167
|
)
|
|
158
168
|
# required for typing with "MMT"
|
|
159
169
|
assert isinstance(ret_val, cls)
|
|
@@ -173,10 +183,12 @@ class MMT(_BaseMMT):
|
|
|
173
183
|
env: Optional[Dict[str, str]] = None,
|
|
174
184
|
interruptible: bool = False,
|
|
175
185
|
cloud_account: Optional[str] = None,
|
|
186
|
+
cloud_provider: Optional[Union["CloudProvider", str]] = None,
|
|
176
187
|
image_credentials: Optional[str] = None,
|
|
177
188
|
cloud_account_auth: bool = False,
|
|
178
189
|
entrypoint: str = "sh -c",
|
|
179
190
|
path_mappings: Optional[Dict[str, str]] = None,
|
|
191
|
+
max_runtime: Optional[int] = None,
|
|
180
192
|
artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
|
|
181
193
|
artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
|
|
182
194
|
) -> "MMT":
|
|
@@ -193,7 +205,11 @@ class MMT(_BaseMMT):
|
|
|
193
205
|
interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
|
|
194
206
|
cloud_account: The cloud account to run the job on.
|
|
195
207
|
Defaults to the studio cloud account if running with studio compute env.
|
|
196
|
-
If not provided
|
|
208
|
+
If not provided and `cloud_account_provider` is set, will resolve cluster from this, else
|
|
209
|
+
will fall back to the teamspaces default cloud account.
|
|
210
|
+
cloud_account_provider: The provider to select the cloud-account from.
|
|
211
|
+
If set, must be in agreement with the provider from the cloud_account (if specified).
|
|
212
|
+
If not specified, falls backto the teamspace default cloud account.
|
|
197
213
|
image_credentials: The credentials used to pull the image. Required if the image is private.
|
|
198
214
|
This should be the name of the respective credentials secret created on the Lightning AI platform.
|
|
199
215
|
cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
|
|
@@ -211,11 +227,16 @@ class MMT(_BaseMMT):
|
|
|
211
227
|
}
|
|
212
228
|
If the path inside the connection is omitted it's assumed to be the root path of that connection.
|
|
213
229
|
Only applicable when submitting docker jobs.
|
|
230
|
+
max_runtime: the duration (in seconds) for which to allocate the machine.
|
|
231
|
+
Irrelevant for most machines, required for some of the top-end machines on GCP.
|
|
232
|
+
If in doubt, set it. Won't have an effect on machines not requiring it.
|
|
233
|
+
Defaults to 3h
|
|
214
234
|
"""
|
|
215
235
|
self._job = self._internal_mmt._submit(
|
|
216
236
|
num_machines=num_machines,
|
|
217
237
|
machine=machine,
|
|
218
238
|
cloud_account=cloud_account,
|
|
239
|
+
cloud_provider=cloud_provider,
|
|
219
240
|
command=command,
|
|
220
241
|
studio=studio,
|
|
221
242
|
image=image,
|
|
@@ -227,6 +248,7 @@ class MMT(_BaseMMT):
|
|
|
227
248
|
path_mappings=path_mappings,
|
|
228
249
|
artifacts_local=artifacts_local,
|
|
229
250
|
artifacts_remote=artifacts_remote,
|
|
251
|
+
max_runtime=max_runtime,
|
|
230
252
|
)
|
|
231
253
|
return self
|
|
232
254
|
|