lightning-sdk 0.2.23__py3-none-any.whl → 0.2.24rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/base_studio_api.py +9 -2
  3. lightning_sdk/api/deployment_api.py +9 -9
  4. lightning_sdk/api/llm_api.py +5 -11
  5. lightning_sdk/api/pipeline_api.py +31 -11
  6. lightning_sdk/api/studio_api.py +4 -0
  7. lightning_sdk/base_studio.py +22 -6
  8. lightning_sdk/deployment/deployment.py +17 -7
  9. lightning_sdk/lightning_cloud/openapi/__init__.py +12 -0
  10. lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
  11. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +114 -1
  12. lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +129 -0
  13. lightning_sdk/lightning_cloud/openapi/api/organizations_service_api.py +105 -0
  14. lightning_sdk/lightning_cloud/openapi/api/pipelines_service_api.py +4 -4
  15. lightning_sdk/lightning_cloud/openapi/api/user_service_api.py +105 -0
  16. lightning_sdk/lightning_cloud/openapi/models/__init__.py +11 -0
  17. lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +27 -1
  18. lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +53 -1
  19. lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body.py +175 -0
  20. lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body1.py +175 -0
  21. lightning_sdk/lightning_cloud/openapi/models/externalv1_user_status.py +79 -1
  22. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +27 -1
  23. lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body1.py +123 -0
  24. lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +27 -1
  25. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  26. lightning_sdk/lightning_cloud/openapi/models/update.py +29 -3
  27. lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +27 -1
  28. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +1 -0
  29. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +29 -3
  31. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +29 -3
  32. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +1 -0
  33. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_specialized_view.py +104 -0
  34. lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_expert.py +279 -0
  35. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +79 -1
  36. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_capacity_reservation.py +27 -1
  37. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  38. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_status.py +27 -1
  39. lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +29 -3
  40. lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +29 -3
  41. lightning_sdk/lightning_cloud/openapi/models/v1_create_organization_request.py +79 -1
  42. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_status.py +47 -21
  43. lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
  44. lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_instance_overprovisioning_spec.py +1 -27
  46. lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +123 -0
  47. lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1_status.py +149 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_list_cloudy_experts_response.py +123 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_login_request.py +27 -1
  50. lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_request.py +29 -3
  51. lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_response.py +27 -1
  52. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +27 -1
  53. lightning_sdk/lightning_cloud/openapi/models/v1_token_usage.py +175 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_update_organization_credits_auto_replenish_response.py +97 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_credits_auto_replenish_response.py +97 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_request.py +27 -1
  57. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +185 -29
  58. lightning_sdk/llm/llm.py +113 -115
  59. lightning_sdk/llm/public_assistants.json +8 -0
  60. lightning_sdk/pipeline/pipeline.py +17 -2
  61. lightning_sdk/pipeline/printer.py +11 -10
  62. lightning_sdk/pipeline/steps.py +4 -1
  63. lightning_sdk/pipeline/utils.py +29 -4
  64. lightning_sdk/studio.py +3 -0
  65. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/METADATA +1 -1
  66. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/RECORD +70 -57
  67. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/LICENSE +0 -0
  68. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/WHEEL +0 -0
  69. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/entry_points.txt +0 -0
  70. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/top_level.txt +0 -0
lightning_sdk/llm/llm.py CHANGED
@@ -1,19 +1,37 @@
1
+ import json
1
2
  import os
2
- import warnings
3
- from typing import AsyncGenerator, Dict, Generator, List, Optional, Set, Tuple, Union
3
+ from typing import AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
4
4
 
5
5
  from lightning_sdk.api.llm_api import LLMApi
6
- from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
7
- from lightning_sdk.lightning_cloud.openapi import V1Assistant
8
6
  from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
9
- from lightning_sdk.lightning_cloud.openapi.rest import ApiException
10
- from lightning_sdk.organization import Organization
11
- from lightning_sdk.owner import Owner
12
- from lightning_sdk.teamspace import Teamspace
13
- from lightning_sdk.utils.resolve import _get_authed_user, _resolve_org, _resolve_teamspace
7
+
8
+ PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
9
+ "openai": "OpenAI",
10
+ "anthropic": "Anthropic",
11
+ "google": "Google",
12
+ }
13
+
14
+
15
+ def _load_public_assistants() -> Dict[str, str]:
16
+ """Load public assistants from a JSON file."""
17
+ try:
18
+ json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
19
+ with open(json_path) as f:
20
+ return json.load(f)
21
+ except Exception as e:
22
+ print(f"[warning] Failed to load public_assistants.json: {e}")
23
+ return {}
14
24
 
15
25
 
16
26
  class LLM:
27
+ _auth_info_cached: ClassVar[bool] = False
28
+ _cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
29
+ _llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
30
+ _public_assistants: ClassVar[Optional[Dict[str, str]]] = None
31
+
32
+ def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
33
+ return super().__new__(cls)
34
+
17
35
  def __init__(
18
36
  self,
19
37
  name: str,
@@ -36,67 +54,18 @@ class LLM:
36
54
  Raises:
37
55
  ValueError: If teamspace information cannot be resolved.
38
56
  """
39
- menu = _TeamspacesMenu()
40
- user = _get_authed_user()
41
- possible_teamspaces = menu._get_possible_teamspaces(user)
42
- if teamspace is None:
43
- # get current teamspace
44
- self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
45
- else:
46
- self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))
47
-
48
- if self._teamspace is None:
49
- # select the first available teamspace
50
- first_teamspace = next(iter(possible_teamspaces.values()), None)
51
-
52
- if first_teamspace:
53
- self._teamspace = Teamspace(
54
- name=first_teamspace["name"],
55
- org=first_teamspace["org"],
56
- user=first_teamspace["user"],
57
- )
58
- warnings.warn(
59
- f"No teamspace given. Using teamspace: {self._teamspace.name}.",
60
- UserWarning,
61
- stacklevel=2,
62
- )
63
-
64
- if self._teamspace is None:
65
- raise ValueError("Teamspace is required for billing but could not be resolved. ")
66
-
67
- self._user = user
57
+ # TODO support user input teamspace
58
+ self._get_auth_info()
68
59
 
69
60
  self._model_provider, self._model_name = self._parse_model_name(name)
70
-
71
- self._llm_api = LLMApi()
72
61
  self._enable_async = enable_async
73
62
 
74
- try:
75
- # check if it is a org model
76
- self._org = _resolve_org(self._model_provider)
63
+ # Reuse LLMApi per teamspace (as billing is based on teamspace)
64
+ if teamspace not in LLM._llm_api_cache:
65
+ LLM._llm_api_cache[teamspace] = LLMApi()
66
+ self._llm_api = LLM._llm_api_cache[teamspace]
77
67
 
78
- try:
79
- # check if user has access to the org
80
- self._org_models = self._build_model_lookup(self._get_org_models())
81
- except ApiException:
82
- warnings.warn(
83
- f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
84
- " Proceeding with appropriate org models, user models, or public models.",
85
- UserWarning,
86
- stacklevel=2,
87
- )
88
- self._model_provider = None
89
- raise
90
- except ApiException:
91
- if isinstance(self._teamspace.owner, Organization):
92
- self._org = self._teamspace.owner
93
- else:
94
- self._org = None
95
- self._org_models = self._build_model_lookup(self._get_org_models())
96
-
97
- self._public_models = self._build_model_lookup(self._get_public_models())
98
- self._user_models = self._build_model_lookup(self._get_user_models())
99
- self._model = self._get_model()
68
+ self._model_id = self._get_model_id()
100
69
  self._conversations = {}
101
70
 
102
71
  @property
@@ -107,9 +76,33 @@ class LLM:
107
76
  def provider(self) -> str:
108
77
  return self._model_provider
109
78
 
110
- @property
111
- def owner(self) -> Optional[Owner]:
112
- return self._teamspace.owner
79
+ def _get_auth_info(self) -> None:
80
+ if not LLM._auth_info_cached:
81
+ teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
82
+ if teamspace_name is None:
83
+ raise ValueError(
84
+ "Teamspace name must be provided either through "
85
+ "the environment variable LIGHTNING_TEAMSPACE or as an argument."
86
+ )
87
+ LLM._cached_auth_info = {
88
+ "teamspace_name": teamspace_name,
89
+ "teamspace_id": os.environ.get("LIGHTNING_CLOUD_PROJECT_ID", None),
90
+ "user_name": os.environ.get("LIGHTNING_USERNAME", ""),
91
+ "user_id": os.environ.get("LIGHTNING_USER_ID", None),
92
+ "org_name": os.environ.get("LIGHTNING_ORG", ""),
93
+ "cloud_url": os.environ.get("LIGHTNING_CLOUD_URL", None),
94
+ }
95
+ LLM._auth_info_cached = True
96
+ if LLM._public_assistants is None:
97
+ LLM._public_assistants = _load_public_assistants()
98
+ # Always assign to the current instance
99
+ self._teamspace_name = LLM._cached_auth_info["teamspace_name"]
100
+ self._teamspace_id = LLM._cached_auth_info["teamspace_id"]
101
+ self._user_name = LLM._cached_auth_info["user_name"]
102
+ self._user_id = LLM._cached_auth_info["user_id"]
103
+ self._org_name = LLM._cached_auth_info["org_name"]
104
+ self._cloud_url = LLM._cached_auth_info["cloud_url"]
105
+ self._org = None
113
106
 
114
107
  def _parse_model_name(self, name: str) -> Tuple[str, str]:
115
108
  parts = name.split("/")
@@ -117,50 +110,60 @@ class LLM:
117
110
  # a user model or a org model
118
111
  return None, parts[0]
119
112
  if len(parts) == 2:
120
- return parts[0], parts[1]
113
+ return parts[0].lower(), parts[1]
121
114
  raise ValueError(
122
115
  f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
123
116
  )
124
117
 
125
- def _build_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
126
- result = {}
127
- for endpoint in endpoints:
128
- result.setdefault(endpoint.model, []).append(endpoint)
129
- return result
130
-
131
- def _get_public_models(self) -> List[str]:
132
- return self._llm_api.get_public_models()
133
-
134
- def _get_org_models(self) -> List[str]:
135
- return self._llm_api.get_org_models(self._org.id) if self._org else []
136
-
137
- def _get_user_models(self) -> List[str]:
138
- return self._llm_api.get_user_models(self._user.id) if self._user else []
139
-
140
- def _get_model(self) -> V1Assistant:
141
- # TODO how to handle multiple models with same model type? For now, just use the first one
142
- if self._model_name in self._public_models:
143
- return self._public_models.get(self._model_name)[0]
144
- if self._model_name in self._org_models:
145
- return self._org_models.get(self._model_name)[0]
146
- if self._model_name in self._user_models:
147
- return self._user_models.get(self._model_name)[0]
148
-
149
- available_models = []
150
- if self._public_models:
151
- available_models.append(f"Public Models: {', '.join(self._public_models.keys())}")
152
-
153
- if self._org and self._org_models:
154
- available_models.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
118
+ # returns the assistant ID
119
+ def _get_model_id(self) -> str:
120
+ if self._model_provider in PUBLIC_MODEL_PROVIDERS:
121
+ # if prod
122
+ if (
123
+ self._cloud_url == "https://lightning.ai"
124
+ and LLM._public_assistants
125
+ and f"{self._model_provider}/{self._model_name}" in LLM._public_assistants
126
+ ):
127
+ return LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]
128
+ try:
129
+ return self._llm_api.get_assistant(
130
+ model_provider=PUBLIC_MODEL_PROVIDERS[self._model_provider],
131
+ model_name=self._model_name,
132
+ user_name="",
133
+ org_name="",
134
+ )
135
+ except Exception as e:
136
+ raise ValueError(
137
+ f"Public model '{self._model_provider}/{self._model_name}' not found. "
138
+ "Please check the model name or provider."
139
+ ) from e
155
140
 
156
- if self._user and self._user_models:
157
- available_models.append(f"User ({self._user.name}) Models: {', '.join(self._user_models.keys())}")
141
+ # Try organization model
142
+ try:
143
+ return self._llm_api.get_assistant(
144
+ model_provider="",
145
+ model_name=self._model_name,
146
+ user_name="",
147
+ org_name=self._model_provider,
148
+ )
149
+ except Exception:
150
+ pass
158
151
 
159
- available_models_str = "\n".join(available_models)
160
- raise ValueError(f"Model '{self._model_name}' not found. \nAvailable models: \n{available_models_str}")
152
+ # Try user model
153
+ try:
154
+ return self._llm_api.get_assistant(
155
+ model_provider="",
156
+ model_name=self._model_name,
157
+ user_name=self._model_provider,
158
+ org_name="",
159
+ )
160
+ except Exception as user_error:
161
+ raise ValueError(
162
+ f"Model '{self._model_provider}/{self._model_name}' not found as either an org or user model.\n"
163
+ ) from user_error
161
164
 
162
165
  def _get_conversations(self) -> None:
163
- conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
166
+ conversations = self._llm_api.list_conversations(assistant_id=self._model_id)
164
167
  for conversation in conversations:
165
168
  if conversation.name and conversation.name not in self._conversations:
166
169
  self._conversations[conversation.name] = conversation.id
@@ -191,7 +194,6 @@ class LLM:
191
194
  conversation: Optional[str] = None,
192
195
  metadata: Optional[Dict[str, str]] = None,
193
196
  stream: bool = False,
194
- upload_local_images: bool = False,
195
197
  ) -> Union[str, AsyncGenerator[str, None]]:
196
198
  conversation_id = self._conversations.get(conversation) if conversation else None
197
199
  output = await self._llm_api.async_start_conversation(
@@ -199,9 +201,9 @@ class LLM:
199
201
  system_prompt=system_prompt,
200
202
  max_completion_tokens=max_completion_tokens,
201
203
  images=images,
202
- assistant_id=self._model.id,
204
+ assistant_id=self._model_id,
203
205
  conversation_id=conversation_id,
204
- billing_project_id=self._teamspace.id,
206
+ billing_project_id=self._teamspace_id,
205
207
  metadata=metadata,
206
208
  name=conversation,
207
209
  stream=stream,
@@ -221,7 +223,6 @@ class LLM:
221
223
  conversation: Optional[str] = None,
222
224
  metadata: Optional[Dict[str, str]] = None,
223
225
  stream: bool = False,
224
- upload_local_images: bool = False,
225
226
  ) -> Union[str, Generator[str, None, None]]:
226
227
  if conversation and conversation not in self._conversations:
227
228
  self._get_conversations()
@@ -232,8 +233,6 @@ class LLM:
232
233
  for image in images:
233
234
  if not isinstance(image, str):
234
235
  raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
235
- if not image.startswith("http") and upload_local_images:
236
- self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
237
236
 
238
237
  conversation_id = self._conversations.get(conversation) if conversation else None
239
238
 
@@ -246,7 +245,6 @@ class LLM:
246
245
  conversation,
247
246
  metadata,
248
247
  stream,
249
- upload_local_images,
250
248
  )
251
249
 
252
250
  output = self._llm_api.start_conversation(
@@ -254,9 +252,9 @@ class LLM:
254
252
  system_prompt=system_prompt,
255
253
  max_completion_tokens=max_completion_tokens,
256
254
  images=images,
257
- assistant_id=self._model.id,
255
+ assistant_id=self._model_id,
258
256
  conversation_id=conversation_id,
259
- billing_project_id=self._teamspace.id,
257
+ billing_project_id=self._teamspace_id,
260
258
  metadata=metadata,
261
259
  name=conversation,
262
260
  stream=stream,
@@ -272,7 +270,7 @@ class LLM:
272
270
  return list(self._conversations.keys())
273
271
 
274
272
  def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
275
- return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
273
+ return self._llm_api.get_conversation(assistant_id=self._model_id, conversation_id=conversation_id)
276
274
 
277
275
  def get_history(self, conversation: str) -> Optional[List[Dict]]:
278
276
  if conversation not in self._conversations:
@@ -297,7 +295,7 @@ class LLM:
297
295
  self._get_conversations()
298
296
  if conversation in self._conversations:
299
297
  self._llm_api.reset_conversation(
300
- assistant_id=self._model.id,
298
+ assistant_id=self._model_id,
301
299
  conversation_id=self._conversations[conversation],
302
300
  )
303
301
  del self._conversations[conversation]
lightning_sdk/llm/public_assistants.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "openai/gpt-4o": "ast_01jdjds71fs8gt47jexzed4czs",
3
+ "openai/gpt-4": "ast_01jd38ze6tjbrcd4942nhz41zn",
4
+ "openai/o3-mini": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
5
+ "anthropic/claude-3-5-sonnet-20240620": "ast_01jd3923a6p98rqwh3dpj686pq",
6
+ "google/gemini-2.5-pro": "ast_01jz3tdb1fhey798k95pv61v57",
7
+ "google/gemini-2.5-flash": "ast_01jz3thxskg4fcdk4xhkjkym5a"
8
+ }
lightning_sdk/pipeline/pipeline.py CHANGED
@@ -86,6 +86,12 @@ class Pipeline:
86
86
  if len(steps) == 0:
87
87
  raise ValueError("The provided steps is empty")
88
88
 
89
+ provided_cloud_account = None
90
+ if self._cloud_account:
91
+ provided_cloud_account = self._cloud_account
92
+ elif self._default_cluster:
93
+ provided_cloud_account = self._default_cluster.cluster_id
94
+
89
95
  for step_idx, pipeline_step in enumerate(steps):
90
96
  if pipeline_step.name in [None, ""]:
91
97
  pipeline_step.name = f"step-{step_idx}"
@@ -98,20 +104,29 @@ class Pipeline:
98
104
  pipeline_step.cloud_account = self._studio.cloud_account
99
105
  pipeline_step.studio = self._studio
100
106
 
107
+ if not pipeline_step.cloud_account and isinstance(provided_cloud_account, str):
108
+ pipeline_step.cloud_account = provided_cloud_account
109
+
101
110
  cluster_ids = set(step.cloud_account for step in steps if step.cloud_account not in ["", None]) # noqa: C401
102
111
 
103
- cloud_account = list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else "" # noqa: RUF015
112
+ cloud_account = (
113
+ list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else "" # noqa: RUF015
114
+ )
104
115
 
105
116
  steps = [step.to_proto(self._teamspace, cloud_account, self._shared_filesystem) for step in steps]
106
117
 
107
118
  proto_steps = prepare_steps(steps)
108
119
  schedules = schedules or []
109
120
 
121
+ for schedule_idx, schedule in enumerate(schedules):
122
+ if schedule.name is None:
123
+ schedule.name = f"schedule-{schedule_idx}"
124
+
110
125
  parent_pipeline_id = None if self._pipeline is None else self._pipeline.id
111
126
 
112
127
  self._pipeline = self._pipeline_api.create_pipeline(
113
128
  self._name,
114
- self._teamspace.id,
129
+ self._teamspace,
115
130
  proto_steps,
116
131
  self._shared_filesystem,
117
132
  schedules,
lightning_sdk/pipeline/printer.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import os
2
2
  from typing import Any, ClassVar, Dict, List
3
3
 
4
- from lightning_sdk.lightning_cloud.openapi.models import V1JobSpec, V1Pipeline, V1PipelineStepType
4
+ from lightning_sdk.lightning_cloud.openapi.models import V1Pipeline, V1PipelineStepType
5
+ from lightning_sdk.pipeline.utils import _get_spec
5
6
 
6
7
 
7
8
  class PipelinePrinter:
@@ -31,7 +32,7 @@ class PipelinePrinter:
31
32
  self._schedules = schedules
32
33
  cluster_ids: set[str] = set()
33
34
  for step in self._proto_steps:
34
- job_spec = self._get_spec(step)
35
+ job_spec = _get_spec(step)
35
36
  cluster_ids.add(job_spec.cluster_id)
36
37
  self._cluster_ids = cluster_ids
37
38
 
@@ -107,15 +108,15 @@ class PipelinePrinter:
107
108
 
108
109
  if self._shared_filesystem.enabled and len(self._cluster_ids) == 1:
109
110
  shared_path = ""
111
+ cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
110
112
  if self._pipeline.shared_filesystem.s3_folder:
111
- cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
112
113
  shared_path = f"/teamspace/s3_folders/pipelines-{cluster_id}"
114
+ if self._pipeline.shared_filesystem.gcs_folder:
115
+ shared_path = f"/teamspace/gcs_folders/pipelines-{cluster_id}"
116
+ if self._pipeline.shared_filesystem.efs:
117
+ shared_path = f"/teamspace/efs_connections/pipelines-{cluster_id}"
118
+ if self._pipeline.shared_filesystem.filestore:
119
+ shared_path = f"/teamspace/gcs_connections/pipelines-{cluster_id}"
120
+
113
121
  if shared_path:
114
122
  self._print(f" - {shared_path}")
115
-
116
- def _get_spec(self, step: Any) -> V1JobSpec:
117
- if step.type == V1PipelineStepType.DEPLOYMENT:
118
- return step.deployment.spec
119
- if step.type == V1PipelineStepType.MMT:
120
- return step.mmt.spec
121
- return step.job.spec
lightning_sdk/pipeline/steps.py CHANGED
@@ -69,12 +69,15 @@ class DeploymentStep:
69
69
 
70
70
  self.machine = machine or Machine.CPU
71
71
  self.image = image
72
+ autoscaling_metric_name = (
73
+ ("CPU" if self.machine.is_cpu() else "GPU") if isinstance(self.machine, Machine) else "CPU"
74
+ )
72
75
  self.autoscale = autoscale or AutoScaleConfig(
73
76
  min_replicas=0,
74
77
  max_replicas=1,
75
78
  target_metrics=[
76
79
  AutoScalingMetric(
77
- name="CPU" if self.machine.is_cpu() else "GPU",
80
+ name=autoscaling_metric_name,
78
81
  target=80,
79
82
  )
80
83
  ],
lightning_sdk/pipeline/utils.py CHANGED
@@ -1,6 +1,11 @@
1
- from typing import List, Literal, Optional, Union
2
-
3
- from lightning_sdk.lightning_cloud.openapi.models import V1PipelineStep, V1SharedFilesystem
1
+ from typing import Any, List, Literal, Optional, Union
2
+
3
+ from lightning_sdk.lightning_cloud.openapi.models import (
4
+ V1JobSpec,
5
+ V1PipelineStep,
6
+ V1PipelineStepType,
7
+ V1SharedFilesystem,
8
+ )
4
9
  from lightning_sdk.studio import Studio
5
10
 
6
11
  DEFAULT = "DEFAULT"
@@ -21,7 +26,7 @@ def prepare_steps(steps: List["V1PipelineStep"]) -> List["V1PipelineStep"]:
21
26
  else:
22
27
  raise ValueError(f"A step with the name {current_step.name} already exists.")
23
28
 
24
- if steps[0].wait_for != DEFAULT:
29
+ if steps[0].wait_for not in [None, DEFAULT, []]:
25
30
  raise ValueError("The first step isn't allowed to receive `wait_for=...`.")
26
31
 
27
32
  steps[0].wait_for = []
@@ -89,3 +94,23 @@ def _to_wait_for(wait_for: Optional[Union[str, List[str]]]) -> Optional[Union[Li
89
94
  return []
90
95
 
91
96
  return wait_for if isinstance(wait_for, list) else [wait_for]
97
+
98
+
99
+ def _get_cloud_account(steps: List[V1PipelineStep]) -> Optional[str]:
100
+ if len(steps) == 0:
101
+ return None
102
+
103
+ cluster_ids: set[str] = set()
104
+ for step in steps:
105
+ job_spec = _get_spec(step)
106
+ cluster_ids.add(job_spec.cluster_id)
107
+
108
+ return sorted(cluster_ids)[0]
109
+
110
+
111
+ def _get_spec(step: Any) -> V1JobSpec:
112
+ if step.type == V1PipelineStepType.DEPLOYMENT:
113
+ return step.deployment.spec
114
+ if step.type == V1PipelineStepType.MMT:
115
+ return step.mmt.spec
116
+ return step.job.spec
lightning_sdk/studio.py CHANGED
@@ -88,6 +88,9 @@ class Studio:
88
88
 
89
89
  self._plugins = {}
90
90
 
91
+ if self._teamspace is None:
92
+ raise ValueError("Couldn't resolve teamspace from the provided name, org, or user")
93
+
91
94
  if provider is not None:
92
95
  if isinstance(provider, str) and provider in Provider.__members__:
93
96
  provider = Provider(provider)
{lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lightning_sdk
3
- Version: 0.2.23
3
+ Version: 0.2.24rc0
4
4
  Summary: SDK to develop using Lightning AI Studios
5
5
  Author-email: Lightning-AI <justus@lightning.ai>
6
6
  License: MIT License