lightning-sdk 0.2.23-py3-none-any.whl → 0.2.24rc0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/api/base_studio_api.py +9 -2
- lightning_sdk/api/deployment_api.py +9 -9
- lightning_sdk/api/llm_api.py +5 -11
- lightning_sdk/api/pipeline_api.py +31 -11
- lightning_sdk/api/studio_api.py +4 -0
- lightning_sdk/base_studio.py +22 -6
- lightning_sdk/deployment/deployment.py +17 -7
- lightning_sdk/lightning_cloud/openapi/__init__.py +12 -0
- lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +114 -1
- lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +129 -0
- lightning_sdk/lightning_cloud/openapi/api/organizations_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/api/pipelines_service_api.py +4 -4
- lightning_sdk/lightning_cloud/openapi/api/user_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +11 -0
- lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body1.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/externalv1_user_status.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body1.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/update.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_specialized_view.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_expert.py +279 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_capacity_reservation.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_status.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_create_organization_request.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_status.py +47 -21
- lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_instance_overprovisioning_spec.py +1 -27
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1_status.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_cloudy_experts_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_login_request.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_request.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_token_usage.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_organization_credits_auto_replenish_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_user_credits_auto_replenish_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_user_request.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +185 -29
- lightning_sdk/llm/llm.py +113 -115
- lightning_sdk/llm/public_assistants.json +8 -0
- lightning_sdk/pipeline/pipeline.py +17 -2
- lightning_sdk/pipeline/printer.py +11 -10
- lightning_sdk/pipeline/steps.py +4 -1
- lightning_sdk/pipeline/utils.py +29 -4
- lightning_sdk/studio.py +3 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/METADATA +1 -1
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/RECORD +70 -57
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/top_level.txt +0 -0
lightning_sdk/llm/llm.py
CHANGED
```diff
@@ -1,19 +1,37 @@
+import json
 import os
-import warnings
-from typing import AsyncGenerator, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
 
 from lightning_sdk.api.llm_api import LLMApi
-from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
-from lightning_sdk.lightning_cloud.openapi import V1Assistant
 from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
-… (old lines 9-13: five removed lines; their content was not captured in this view)
+
+PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
+    "openai": "OpenAI",
+    "anthropic": "Anthropic",
+    "google": "Google",
+}
+
+
+def _load_public_assistants() -> Dict[str, str]:
+    """Load public assistants from a JSON file."""
+    try:
+        json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
+        with open(json_path) as f:
+            return json.load(f)
+    except Exception as e:
+        print(f"[warning] Failed to load public_assistants.json: {e}")
+        return {}
 
 
 class LLM:
+    _auth_info_cached: ClassVar[bool] = False
+    _cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
+    _llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
+    _public_assistants: ClassVar[Optional[Dict[str, str]]] = None
+
+    def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
+        return super().__new__(cls)
+
     def __init__(
         self,
         name: str,
@@ -36,67 +54,18 @@ class LLM:
         Raises:
             ValueError: If teamspace information cannot be resolved.
         """
-
-
-        possible_teamspaces = menu._get_possible_teamspaces(user)
-        if teamspace is None:
-            # get current teamspace
-            self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
-        else:
-            self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))
-
-        if self._teamspace is None:
-            # select the first available teamspace
-            first_teamspace = next(iter(possible_teamspaces.values()), None)
-
-            if first_teamspace:
-                self._teamspace = Teamspace(
-                    name=first_teamspace["name"],
-                    org=first_teamspace["org"],
-                    user=first_teamspace["user"],
-                )
-                warnings.warn(
-                    f"No teamspace given. Using teamspace: {self._teamspace.name}.",
-                    UserWarning,
-                    stacklevel=2,
-                )
-
-        if self._teamspace is None:
-            raise ValueError("Teamspace is required for billing but could not be resolved. ")
-
-        self._user = user
+        # TODO support user input teamspace
+        self._get_auth_info()
 
         self._model_provider, self._model_name = self._parse_model_name(name)
-
-        self._llm_api = LLMApi()
         self._enable_async = enable_async
 
-
-
-
+        # Reuse LLMApi per teamspace (as billing is based on teamspace)
+        if teamspace not in LLM._llm_api_cache:
+            LLM._llm_api_cache[teamspace] = LLMApi()
+        self._llm_api = LLM._llm_api_cache[teamspace]
 
-
-            # check if user has access to the org
-            self._org_models = self._build_model_lookup(self._get_org_models())
-        except ApiException:
-            warnings.warn(
-                f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
-                " Proceeding with appropriate org models, user models, or public models.",
-                UserWarning,
-                stacklevel=2,
-            )
-            self._model_provider = None
-            raise
-        except ApiException:
-            if isinstance(self._teamspace.owner, Organization):
-                self._org = self._teamspace.owner
-            else:
-                self._org = None
-            self._org_models = self._build_model_lookup(self._get_org_models())
-
-        self._public_models = self._build_model_lookup(self._get_public_models())
-        self._user_models = self._build_model_lookup(self._get_user_models())
-        self._model = self._get_model()
+        self._model_id = self._get_model_id()
         self._conversations = {}
 
     @property
@@ -107,9 +76,33 @@ class LLM:
     def provider(self) -> str:
         return self._model_provider
 
-
-
-
+    def _get_auth_info(self) -> None:
+        if not LLM._auth_info_cached:
+            teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
+            if teamspace_name is None:
+                raise ValueError(
+                    "Teamspace name must be provided either through "
+                    "the environment variable LIGHTNING_TEAMSPACE or as an argument."
+                )
+            LLM._cached_auth_info = {
+                "teamspace_name": teamspace_name,
+                "teamspace_id": os.environ.get("LIGHTNING_CLOUD_PROJECT_ID", None),
+                "user_name": os.environ.get("LIGHTNING_USERNAME", ""),
+                "user_id": os.environ.get("LIGHTNING_USER_ID", None),
+                "org_name": os.environ.get("LIGHTNING_ORG", ""),
+                "cloud_url": os.environ.get("LIGHTNING_CLOUD_URL", None),
+            }
+            LLM._auth_info_cached = True
+        if LLM._public_assistants is None:
+            LLM._public_assistants = _load_public_assistants()
+        # Always assign to the current instance
+        self._teamspace_name = LLM._cached_auth_info["teamspace_name"]
+        self._teamspace_id = LLM._cached_auth_info["teamspace_id"]
+        self._user_name = LLM._cached_auth_info["user_name"]
+        self._user_id = LLM._cached_auth_info["user_id"]
+        self._org_name = LLM._cached_auth_info["org_name"]
+        self._cloud_url = LLM._cached_auth_info["cloud_url"]
+        self._org = None
 
     def _parse_model_name(self, name: str) -> Tuple[str, str]:
         parts = name.split("/")
@@ -117,50 +110,60 @@ class LLM:
             # a user model or a org model
             return None, parts[0]
         if len(parts) == 2:
-            return parts[0], parts[1]
+            return parts[0].lower(), parts[1]
         raise ValueError(
             f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
         )
 
-… (old lines 125-146: the removed _get_model() lookup logic; content not captured in this view)
-            return self._user_models.get(self._model_name)[0]
-
-        available_models = []
-        if self._public_models:
-            available_models.append(f"Public Models: {', '.join(self._public_models.keys())}")
-
-        if self._org and self._org_models:
-            available_models.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
+    # returns the assistant ID
+    def _get_model_id(self) -> str:
+        if self._model_provider in PUBLIC_MODEL_PROVIDERS:
+            # if prod
+            if (
+                self._cloud_url == "https://lightning.ai"
+                and LLM._public_assistants
+                and f"{self._model_provider}/{self._model_name}" in LLM._public_assistants
+            ):
+                return LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]
+            try:
+                return self._llm_api.get_assistant(
+                    model_provider=PUBLIC_MODEL_PROVIDERS[self._model_provider],
+                    model_name=self._model_name,
+                    user_name="",
+                    org_name="",
+                )
+            except Exception as e:
+                raise ValueError(
+                    f"Public model '{self._model_provider}/{self._model_name}' not found. "
+                    "Please check the model name or provider."
+                ) from e
 
-
-
+        # Try organization model
+        try:
+            return self._llm_api.get_assistant(
+                model_provider="",
+                model_name=self._model_name,
+                user_name="",
+                org_name=self._model_provider,
+            )
+        except Exception:
+            pass
 
-
-
+        # Try user model
+        try:
+            return self._llm_api.get_assistant(
+                model_provider="",
+                model_name=self._model_name,
+                user_name=self._model_provider,
+                org_name="",
+            )
+        except Exception as user_error:
+            raise ValueError(
+                f"Model '{self._model_provider}/{self._model_name}' not found as either an org or user model.\n"
+            ) from user_error
 
     def _get_conversations(self) -> None:
-        conversations = self._llm_api.list_conversations(assistant_id=self.
+        conversations = self._llm_api.list_conversations(assistant_id=self._model_id)
         for conversation in conversations:
             if conversation.name and conversation.name not in self._conversations:
                 self._conversations[conversation.name] = conversation.id
@@ -191,7 +194,6 @@ class LLM:
         conversation: Optional[str] = None,
         metadata: Optional[Dict[str, str]] = None,
         stream: bool = False,
-        upload_local_images: bool = False,
     ) -> Union[str, AsyncGenerator[str, None]]:
         conversation_id = self._conversations.get(conversation) if conversation else None
         output = await self._llm_api.async_start_conversation(
@@ -199,9 +201,9 @@ class LLM:
             system_prompt=system_prompt,
             max_completion_tokens=max_completion_tokens,
             images=images,
-            assistant_id=self.
+            assistant_id=self._model_id,
             conversation_id=conversation_id,
-            billing_project_id=self.
+            billing_project_id=self._teamspace_id,
             metadata=metadata,
             name=conversation,
             stream=stream,
@@ -221,7 +223,6 @@ class LLM:
         conversation: Optional[str] = None,
         metadata: Optional[Dict[str, str]] = None,
         stream: bool = False,
-        upload_local_images: bool = False,
     ) -> Union[str, Generator[str, None, None]]:
         if conversation and conversation not in self._conversations:
             self._get_conversations()
@@ -232,8 +233,6 @@ class LLM:
         for image in images:
             if not isinstance(image, str):
                 raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
-            if not image.startswith("http") and upload_local_images:
-                self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
 
         conversation_id = self._conversations.get(conversation) if conversation else None
 
@@ -246,7 +245,6 @@ class LLM:
             conversation,
             metadata,
            stream,
-            upload_local_images,
        )
 
        output = self._llm_api.start_conversation(
@@ -254,9 +252,9 @@ class LLM:
             system_prompt=system_prompt,
             max_completion_tokens=max_completion_tokens,
             images=images,
-            assistant_id=self.
+            assistant_id=self._model_id,
             conversation_id=conversation_id,
-            billing_project_id=self.
+            billing_project_id=self._teamspace_id,
             metadata=metadata,
             name=conversation,
             stream=stream,
@@ -272,7 +270,7 @@ class LLM:
         return list(self._conversations.keys())
 
     def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
-        return self._llm_api.get_conversation(assistant_id=self.
+        return self._llm_api.get_conversation(assistant_id=self._model_id, conversation_id=conversation_id)
 
     def get_history(self, conversation: str) -> Optional[List[Dict]]:
         if conversation not in self._conversations:
@@ -297,7 +295,7 @@ class LLM:
             self._get_conversations()
         if conversation in self._conversations:
             self._llm_api.reset_conversation(
-                assistant_id=self.
+                assistant_id=self._model_id,
                 conversation_id=self._conversations[conversation],
             )
             del self._conversations[conversation]
```
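The most consequential structural change in `llm.py` is the class-level `_llm_api_cache`: every `LLM` instance pointed at the same teamspace now shares one `LLMApi` client, since billing is tracked per teamspace. A minimal sketch of that caching pattern, with a plain `object()` standing in for the real `LLMApi`:

```python
from typing import ClassVar, Dict, Optional


class Client:
    """Per-key client cache, mirroring LLM._llm_api_cache in the diff."""

    _api_cache: ClassVar[Dict[Optional[str], object]] = {}

    def __init__(self, teamspace: Optional[str] = None) -> None:
        # Create the underlying API client once per teamspace, then reuse it.
        if teamspace not in Client._api_cache:
            Client._api_cache[teamspace] = object()  # stand-in for LLMApi()
        self._api = Client._api_cache[teamspace]


a, b = Client("team-a"), Client("team-a")
assert a._api is b._api                      # same teamspace -> shared client
assert Client("team-b")._api is not a._api   # different teamspace -> new client
```

The same ClassVar approach backs `_cached_auth_info` (environment variables are read once per process) and `_public_assistants` (the JSON table below is loaded once).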
lightning_sdk/llm/public_assistants.json
ADDED
```diff
@@ -0,0 +1,8 @@
+{
+    "openai/gpt-4o": "ast_01jdjds71fs8gt47jexzed4czs",
+    "openai/gpt-4": "ast_01jd38ze6tjbrcd4942nhz41zn",
+    "openai/o3-mini": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
+    "anthropic/claude-3-5-sonnet-20240620": "ast_01jd3923a6p98rqwh3dpj686pq",
+    "google/gemini-2.5-pro": "ast_01jz3tdb1fhey798k95pv61v57",
+    "google/gemini-2.5-flash": "ast_01jz3thxskg4fcdk4xhkjkym5a"
+}
```
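These IDs short-circuit assistant resolution in production: when `cloud_url` is `https://lightning.ai`, `_get_model_id` answers from this table without calling `get_assistant`. A sketch of the lookup (the path is assumed; per the diff, the file ships next to `llm.py`):

```python
import json
import os

# Path assumed from the diff: the JSON sits alongside lightning_sdk/llm/llm.py.
json_path = os.path.join("lightning_sdk", "llm", "public_assistants.json")

with open(json_path) as f:
    public_assistants = json.load(f)

# "provider/model" key -> assistant ID, skipping a network round-trip.
print(public_assistants["openai/gpt-4o"])  # ast_01jdjds71fs8gt47jexzed4czs
```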
lightning_sdk/pipeline/pipeline.py
CHANGED
```diff
@@ -86,6 +86,12 @@ class Pipeline:
         if len(steps) == 0:
             raise ValueError("The provided steps is empty")
 
+        provided_cloud_account = None
+        if self._cloud_account:
+            provided_cloud_account = self._cloud_account
+        elif self._default_cluster:
+            provided_cloud_account = self._default_cluster.cluster_id
+
         for step_idx, pipeline_step in enumerate(steps):
             if pipeline_step.name in [None, ""]:
                 pipeline_step.name = f"step-{step_idx}"
@@ -98,20 +104,29 @@ class Pipeline:
                 pipeline_step.cloud_account = self._studio.cloud_account
                 pipeline_step.studio = self._studio
 
+            if not pipeline_step.cloud_account and isinstance(provided_cloud_account, str):
+                pipeline_step.cloud_account = provided_cloud_account
+
         cluster_ids = set(step.cloud_account for step in steps if step.cloud_account not in ["", None])  # noqa: C401
 
-        cloud_account =
+        cloud_account = (
+            list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else ""  # noqa: RUF015
+        )
 
         steps = [step.to_proto(self._teamspace, cloud_account, self._shared_filesystem) for step in steps]
 
         proto_steps = prepare_steps(steps)
         schedules = schedules or []
 
+        for schedule_idx, schedule in enumerate(schedules):
+            if schedule.name is None:
+                schedule.name = f"schedule-{schedule_idx}"
+
         parent_pipeline_id = None if self._pipeline is None else self._pipeline.id
 
         self._pipeline = self._pipeline_api.create_pipeline(
             self._name,
-            self._teamspace
+            self._teamspace,
             proto_steps,
             self._shared_filesystem,
             schedules,
```
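The effect of `provided_cloud_account` is a three-level fallback for each step's cluster: an account set on the step itself wins, then the pipeline-level `cloud_account`, then the teamspace's default cluster. A standalone sketch of that precedence (function and names are illustrative, not the SDK's API):

```python
from typing import Optional


def resolve_cloud_account(
    step_account: Optional[str],
    pipeline_account: Optional[str],
    default_cluster_id: Optional[str],
) -> Optional[str]:
    """Step-level account > pipeline-level account > teamspace default cluster."""
    if step_account:
        return step_account
    return pipeline_account or default_cluster_id


assert resolve_cloud_account("step-cluster", "pipe-cluster", "default") == "step-cluster"
assert resolve_cloud_account(None, "pipe-cluster", "default") == "pipe-cluster"
assert resolve_cloud_account(None, None, "default") == "default"
```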
lightning_sdk/pipeline/printer.py
CHANGED
```diff
@@ -1,7 +1,8 @@
 import os
 from typing import Any, ClassVar, Dict, List
 
-from lightning_sdk.lightning_cloud.openapi.models import
+from lightning_sdk.lightning_cloud.openapi.models import V1Pipeline, V1PipelineStepType
+from lightning_sdk.pipeline.utils import _get_spec
 
 
 class PipelinePrinter:
@@ -31,7 +32,7 @@ class PipelinePrinter:
         self._schedules = schedules
         cluster_ids: set[str] = set()
         for step in self._proto_steps:
-            job_spec =
+            job_spec = _get_spec(step)
             cluster_ids.add(job_spec.cluster_id)
         self._cluster_ids = cluster_ids
 
@@ -107,15 +108,15 @@ class PipelinePrinter:
 
         if self._shared_filesystem.enabled and len(self._cluster_ids) == 1:
             shared_path = ""
+            cluster_id = list(self._cluster_ids)[0]  # noqa: RUF015
             if self._pipeline.shared_filesystem.s3_folder:
-                cluster_id = list(self._cluster_ids)[0]  # noqa: RUF015
                 shared_path = f"/teamspace/s3_folders/pipelines-{cluster_id}"
+            if self._pipeline.shared_filesystem.gcs_folder:
+                shared_path = f"/teamspace/gcs_folders/pipelines-{cluster_id}"
+            if self._pipeline.shared_filesystem.efs:
+                shared_path = f"/teamspace/efs_connections/pipelines-{cluster_id}"
+            if self._pipeline.shared_filesystem.filestore:
+                shared_path = f"/teamspace/gcs_connections/pipelines-{cluster_id}"
+
             if shared_path:
                 self._print(f"  - {shared_path}")
-
-    def _get_spec(self, step: Any) -> V1JobSpec:
-        if step.type == V1PipelineStepType.DEPLOYMENT:
-            return step.deployment.spec
-        if step.type == V1PipelineStepType.MMT:
-            return step.mmt.spec
-        return step.job.spec
```
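The printer previously reported a shared path only for S3 folders; it now covers all four shared-filesystem backends. The mapping, pulled out of the diff into a standalone sketch (note that `filestore` maps to a `gcs_connections` path, which is easy to misread next to `gcs_folder` but may be intentional, Filestore being a GCP service):

```python
# Backend flag on V1SharedFilesystem -> mount path the printer reports.
SHARED_PATH_TEMPLATES = {
    "s3_folder": "/teamspace/s3_folders/pipelines-{cluster_id}",
    "gcs_folder": "/teamspace/gcs_folders/pipelines-{cluster_id}",
    "efs": "/teamspace/efs_connections/pipelines-{cluster_id}",
    "filestore": "/teamspace/gcs_connections/pipelines-{cluster_id}",
}

print(SHARED_PATH_TEMPLATES["efs"].format(cluster_id="cluster-123"))
# -> /teamspace/efs_connections/pipelines-cluster-123
```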
lightning_sdk/pipeline/steps.py
CHANGED
```diff
@@ -69,12 +69,15 @@ class DeploymentStep:
 
         self.machine = machine or Machine.CPU
         self.image = image
+        autoscaling_metric_name = (
+            ("CPU" if self.machine.is_cpu() else "GPU") if isinstance(self.machine, Machine) else "CPU"
+        )
         self.autoscale = autoscale or AutoScaleConfig(
             min_replicas=0,
             max_replicas=1,
             target_metrics=[
                 AutoScalingMetric(
-                    name=
+                    name=autoscaling_metric_name,
                     target=80,
                 )
             ],
```
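`DeploymentStep` now picks the default autoscaling metric from the machine type instead of a fixed name: CPU machines scale on CPU utilization, GPU machines on GPU utilization, with CPU as the fallback when `machine` isn't a `Machine`. The nested conditional, unpacked long-form (assumes `Machine` is importable from the package top level and exposes `is_cpu()`, as the diff suggests):

```python
from lightning_sdk import Machine


def default_autoscaling_metric(machine: object) -> str:
    """Long-form equivalent of the nested conditional in DeploymentStep."""
    if not isinstance(machine, Machine):
        return "CPU"  # conservative fallback for non-Machine inputs
    return "CPU" if machine.is_cpu() else "GPU"


print(default_autoscaling_metric(Machine.CPU))  # CPU
print(default_autoscaling_metric("a10g"))       # CPU (a string, not a Machine)
```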
lightning_sdk/pipeline/utils.py
CHANGED
```diff
@@ -1,6 +1,11 @@
-from typing import List, Literal, Optional, Union
-
-from lightning_sdk.lightning_cloud.openapi.models import
+from typing import Any, List, Literal, Optional, Union
+
+from lightning_sdk.lightning_cloud.openapi.models import (
+    V1JobSpec,
+    V1PipelineStep,
+    V1PipelineStepType,
+    V1SharedFilesystem,
+)
 from lightning_sdk.studio import Studio
 
 DEFAULT = "DEFAULT"
@@ -21,7 +26,7 @@ def prepare_steps(steps: List["V1PipelineStep"]) -> List["V1PipelineStep"]:
         else:
             raise ValueError(f"A step with the name {current_step.name} already exists.")
 
-    if steps[0].wait_for
+    if steps[0].wait_for not in [None, DEFAULT, []]:
         raise ValueError("The first step isn't allowed to receive `wait_for=...`.")
 
     steps[0].wait_for = []
@@ -89,3 +94,23 @@ def _to_wait_for(wait_for: Optional[Union[str, List[str]]]) -> Optional[Union[Li
         return []
 
     return wait_for if isinstance(wait_for, list) else [wait_for]
+
+
+def _get_cloud_account(steps: List[V1PipelineStep]) -> Optional[str]:
+    if len(steps) == 0:
+        return None
+
+    cluster_ids: set[str] = set()
+    for step in steps:
+        job_spec = _get_spec(step)
+        cluster_ids.add(job_spec.cluster_id)
+
+    return sorted(cluster_ids)[0]
+
+
+def _get_spec(step: Any) -> V1JobSpec:
+    if step.type == V1PipelineStepType.DEPLOYMENT:
+        return step.deployment.spec
+    if step.type == V1PipelineStepType.MMT:
+        return step.mmt.spec
+    return step.job.spec
```
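Two details of the new helpers are worth calling out: `_get_spec` is now a module-level function shared with the printer (it moved out of `PipelinePrinter`), and `_get_cloud_account` returns `sorted(cluster_ids)[0]`, so a multi-cluster pipeline resolves to the lexicographically smallest cluster ID deterministically instead of depending on set iteration order. A sketch with `SimpleNamespace` stand-ins rather than the SDK's protos:

```python
from types import SimpleNamespace


def make_job_step(cluster_id: str) -> SimpleNamespace:
    # Each step carries a job spec with a cluster_id, as in the diff; the
    # "JOB" type falls through _get_spec's checks to step.job.spec.
    return SimpleNamespace(
        type="JOB",
        job=SimpleNamespace(spec=SimpleNamespace(cluster_id=cluster_id)),
    )


steps = [make_job_step("zeta"), make_job_step("alpha")]
cluster_ids = {step.job.spec.cluster_id for step in steps}
print(sorted(cluster_ids)[0])  # alpha -- deterministic regardless of set order
```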
lightning_sdk/studio.py
CHANGED
```diff
@@ -88,6 +88,9 @@ class Studio:
 
         self._plugins = {}
 
+        if self._teamspace is None:
+            raise ValueError("Couldn't resolve teamspace from the provided name, org, or user")
+
         if provider is not None:
             if isinstance(provider, str) and provider in Provider.__members__:
                 provider = Provider(provider)
```