lightning-sdk 0.2.23__py3-none-any.whl → 0.2.24rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/base_studio_api.py +9 -2
  3. lightning_sdk/api/deployment_api.py +9 -9
  4. lightning_sdk/api/llm_api.py +23 -13
  5. lightning_sdk/api/pipeline_api.py +31 -11
  6. lightning_sdk/api/studio_api.py +4 -0
  7. lightning_sdk/base_studio.py +22 -6
  8. lightning_sdk/deployment/deployment.py +17 -7
  9. lightning_sdk/lightning_cloud/openapi/__init__.py +18 -0
  10. lightning_sdk/lightning_cloud/openapi/api/__init__.py +2 -0
  11. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +114 -1
  12. lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +129 -0
  13. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +9 -1
  14. lightning_sdk/lightning_cloud/openapi/api/lit_logger_service_api.py +13 -1
  15. lightning_sdk/lightning_cloud/openapi/api/organizations_service_api.py +105 -0
  16. lightning_sdk/lightning_cloud/openapi/api/pipelines_service_api.py +4 -4
  17. lightning_sdk/lightning_cloud/openapi/api/user_service_api.py +105 -0
  18. lightning_sdk/lightning_cloud/openapi/api/volume_service_api.py +258 -0
  19. lightning_sdk/lightning_cloud/openapi/models/__init__.py +16 -0
  20. lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +27 -1
  21. lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +79 -1
  22. lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body.py +175 -0
  23. lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body1.py +175 -0
  24. lightning_sdk/lightning_cloud/openapi/models/externalv1_user_status.py +53 -1
  25. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +27 -1
  26. lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body1.py +123 -0
  27. lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +27 -1
  28. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  29. lightning_sdk/lightning_cloud/openapi/models/project_id_schedules_body.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/schedules_id_body.py +27 -1
  31. lightning_sdk/lightning_cloud/openapi/models/update.py +29 -3
  32. lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/v1_billing_tier.py +1 -0
  34. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +1 -0
  35. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +27 -1
  36. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +29 -3
  37. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +29 -3
  38. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +1 -0
  39. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_specialized_view.py +104 -0
  40. lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_expert.py +279 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +79 -1
  42. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_capacity_reservation.py +53 -1
  43. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  44. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_status.py +27 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +29 -3
  46. lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +29 -3
  47. lightning_sdk/lightning_cloud/openapi/models/v1_create_organization_request.py +79 -1
  48. lightning_sdk/lightning_cloud/openapi/models/v1_data_connection_tier.py +103 -0
  49. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_status.py +47 -21
  50. lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
  51. lightning_sdk/lightning_cloud/openapi/models/v1_filestore_data_connection.py +29 -3
  52. lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
  53. lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
  54. lightning_sdk/lightning_cloud/openapi/models/v1_get_volume_response.py +123 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_instance_overprovisioning_spec.py +1 -27
  56. lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +149 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1_status.py +149 -0
  58. lightning_sdk/lightning_cloud/openapi/models/v1_list_cloudy_experts_response.py +123 -0
  59. lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_response.py +27 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +27 -1
  61. lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +1 -0
  62. lightning_sdk/lightning_cloud/openapi/models/v1_schedule.py +27 -1
  63. lightning_sdk/lightning_cloud/openapi/models/v1_schedule_action_type.py +104 -0
  64. lightning_sdk/lightning_cloud/openapi/models/v1_schedule_resource_type.py +1 -0
  65. lightning_sdk/lightning_cloud/openapi/models/v1_token_usage.py +175 -0
  66. lightning_sdk/lightning_cloud/openapi/models/v1_update_organization_credits_auto_replenish_response.py +97 -0
  67. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_credits_auto_replenish_response.py +97 -0
  68. lightning_sdk/lightning_cloud/openapi/models/v1_update_volume_response.py +123 -0
  69. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +317 -31
  70. lightning_sdk/lightning_cloud/openapi/models/volumes_id_body.py +123 -0
  71. lightning_sdk/llm/llm.py +118 -115
  72. lightning_sdk/llm/public_assistants.json +8 -0
  73. lightning_sdk/pipeline/pipeline.py +17 -2
  74. lightning_sdk/pipeline/printer.py +11 -10
  75. lightning_sdk/pipeline/steps.py +4 -1
  76. lightning_sdk/pipeline/utils.py +29 -4
  77. lightning_sdk/studio.py +3 -0
  78. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/METADATA +1 -1
  79. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/RECORD +83 -64
  80. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/LICENSE +0 -0
  81. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/WHEEL +0 -0
  82. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/entry_points.txt +0 -0
  83. {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,123 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ external/v1/auth_service.proto
5
+
6
+ No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) # noqa: E501
7
+
8
+ OpenAPI spec version: version not set
9
+
10
+ Generated by: https://github.com/swagger-api/swagger-codegen.git
11
+
12
+ NOTE
13
+ ----
14
+ standard swagger-codegen-cli for this python client has been modified
15
+ by custom templates. The purpose of these templates is to include
16
+ typing information in the API and Model code. Please refer to the
17
+ main grid repository for more info
18
+ """
19
+
20
+ import pprint
21
+ import re # noqa: F401
22
+
23
+ from typing import TYPE_CHECKING
24
+
25
+ import six
26
+
27
+ if TYPE_CHECKING:
28
+ from datetime import datetime
29
+ from lightning_sdk.lightning_cloud.openapi.models import *
30
+
31
class VolumesIdBody(object):
    """Request body carrying a single ``V1Volume`` under the ``"volume"`` key.

    NOTE: This class is auto generated by the swagger code generator program.
    Do not edit the class manually.

    Attributes:
      swagger_types (dict): The key is attribute name
                            and the value is attribute type.
      attribute_map (dict): The key is attribute name
                            and the value is json key in definition.
    """

    swagger_types = {'volume': 'V1Volume'}

    attribute_map = {'volume': 'volume'}

    def __init__(self, volume: 'V1Volume' = None):  # noqa: E501
        """Create the body; ``volume`` is only stored when it is not ``None``."""
        self._volume = None
        self.discriminator = None
        if volume is None:
            return
        self.volume = volume

    @property
    def volume(self) -> 'V1Volume':
        """The volume held by this body (``None`` until set).

        :rtype: V1Volume
        """
        return self._volume

    @volume.setter
    def volume(self, volume: 'V1Volume'):
        """Store ``volume`` on this body as-is (no validation is performed).

        :type: V1Volume
        """
        self._volume = volume

    def to_dict(self) -> dict:
        """Serialize every declared property into a plain dict.

        Values exposing ``to_dict`` (nested models) are serialized recursively,
        including when they sit inside a list or as dict values.
        """

        def _convert(item):
            # Nested swagger models know how to serialize themselves.
            return item.to_dict() if hasattr(item, "to_dict") else item

        result = {}
        for name, _ in six.iteritems(self.swagger_types):
            current = getattr(self, name)
            if isinstance(current, list):
                result[name] = [_convert(entry) for entry in current]
            elif hasattr(current, "to_dict"):
                result[name] = current.to_dict()
            elif isinstance(current, dict):
                result[name] = {key: _convert(val) for key, val in current.items()}
            else:
                result[name] = current
        if issubclass(VolumesIdBody, dict):
            for key, val in self.items():
                result[key] = val
        return result

    def to_str(self) -> str:
        """Pretty-printed form of ``to_dict()``."""
        return pprint.pformat(self.to_dict())

    def __repr__(self) -> str:
        """For `print` and `pprint`"""
        return self.to_str()

    def __eq__(self, other: 'VolumesIdBody') -> bool:
        """Equal only to another ``VolumesIdBody`` with identical instance state."""
        return isinstance(other, VolumesIdBody) and self.__dict__ == other.__dict__

    def __ne__(self, other: 'VolumesIdBody') -> bool:
        """Negation of ``==``."""
        return not (self == other)
lightning_sdk/llm/llm.py CHANGED
@@ -1,19 +1,37 @@
1
+ import json
1
2
  import os
2
- import warnings
3
- from typing import AsyncGenerator, Dict, Generator, List, Optional, Set, Tuple, Union
3
+ from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
4
4
 
5
5
  from lightning_sdk.api.llm_api import LLMApi
6
- from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
7
- from lightning_sdk.lightning_cloud.openapi import V1Assistant
8
6
  from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
9
- from lightning_sdk.lightning_cloud.openapi.rest import ApiException
10
- from lightning_sdk.organization import Organization
11
- from lightning_sdk.owner import Owner
12
- from lightning_sdk.teamspace import Teamspace
13
- from lightning_sdk.utils.resolve import _get_authed_user, _resolve_org, _resolve_teamspace
7
+
8
+ PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
9
+ "openai": "OpenAI",
10
+ "anthropic": "Anthropic",
11
+ "google": "Google",
12
+ }
13
+
14
+
15
def _load_public_assistants() -> Dict[str, str]:
    """Load the bundled ``public_assistants.json`` mapping that sits next to this module.

    The file maps ``"<provider>/<model>"`` keys to assistant IDs. Loading is
    best-effort. If the file cannot be read or parsed, or does not contain a
    JSON object, a ``UserWarning`` is emitted and an empty dict is returned.

    Returns:
        The ``"<provider>/<model>" -> assistant ID`` mapping, or ``{}`` when unavailable.
    """
    import warnings

    json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
    try:
        # Packaged data: decode as UTF-8 regardless of the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError / UnicodeDecodeError
        warnings.warn(f"Failed to load public_assistants.json: {e}", UserWarning, stacklevel=2)
        return {}

    if not isinstance(data, dict):
        # A list/scalar would break the ``key in mapping`` / ``mapping[key]`` lookups callers do.
        warnings.warn(
            f"Failed to load public_assistants.json: expected a JSON object, got {type(data).__name__}",
            UserWarning,
            stacklevel=2,
        )
        return {}
    return data
14
24
 
15
25
 
16
26
  class LLM:
27
+ _auth_info_cached: ClassVar[bool] = False
28
+ _cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
29
+ _llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
30
+ _public_assistants: ClassVar[Optional[Dict[str, str]]] = None
31
+
32
+ def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
33
+ return super().__new__(cls)
34
+
17
35
  def __init__(
18
36
  self,
19
37
  name: str,
@@ -36,67 +54,18 @@ class LLM:
36
54
  Raises:
37
55
  ValueError: If teamspace information cannot be resolved.
38
56
  """
39
- menu = _TeamspacesMenu()
40
- user = _get_authed_user()
41
- possible_teamspaces = menu._get_possible_teamspaces(user)
42
- if teamspace is None:
43
- # get current teamspace
44
- self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
45
- else:
46
- self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))
47
-
48
- if self._teamspace is None:
49
- # select the first available teamspace
50
- first_teamspace = next(iter(possible_teamspaces.values()), None)
51
-
52
- if first_teamspace:
53
- self._teamspace = Teamspace(
54
- name=first_teamspace["name"],
55
- org=first_teamspace["org"],
56
- user=first_teamspace["user"],
57
- )
58
- warnings.warn(
59
- f"No teamspace given. Using teamspace: {self._teamspace.name}.",
60
- UserWarning,
61
- stacklevel=2,
62
- )
63
-
64
- if self._teamspace is None:
65
- raise ValueError("Teamspace is required for billing but could not be resolved. ")
66
-
67
- self._user = user
57
+ # TODO support user input teamspace
58
+ self._get_auth_info()
68
59
 
69
60
  self._model_provider, self._model_name = self._parse_model_name(name)
70
-
71
- self._llm_api = LLMApi()
72
61
  self._enable_async = enable_async
73
62
 
74
- try:
75
- # check if it is a org model
76
- self._org = _resolve_org(self._model_provider)
63
+ # Reuse LLMApi per teamspace (as billing is based on teamspace)
64
+ if teamspace not in LLM._llm_api_cache:
65
+ LLM._llm_api_cache[teamspace] = LLMApi()
66
+ self._llm_api = LLM._llm_api_cache[teamspace]
77
67
 
78
- try:
79
- # check if user has access to the org
80
- self._org_models = self._build_model_lookup(self._get_org_models())
81
- except ApiException:
82
- warnings.warn(
83
- f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
84
- " Proceeding with appropriate org models, user models, or public models.",
85
- UserWarning,
86
- stacklevel=2,
87
- )
88
- self._model_provider = None
89
- raise
90
- except ApiException:
91
- if isinstance(self._teamspace.owner, Organization):
92
- self._org = self._teamspace.owner
93
- else:
94
- self._org = None
95
- self._org_models = self._build_model_lookup(self._get_org_models())
96
-
97
- self._public_models = self._build_model_lookup(self._get_public_models())
98
- self._user_models = self._build_model_lookup(self._get_user_models())
99
- self._model = self._get_model()
68
+ self._model_id = self._get_model_id()
100
69
  self._conversations = {}
101
70
 
102
71
  @property
@@ -107,9 +76,33 @@ class LLM:
107
76
  def provider(self) -> str:
108
77
  return self._model_provider
109
78
 
110
- @property
111
- def owner(self) -> Optional[Owner]:
112
- return self._teamspace.owner
79
def _get_auth_info(self) -> None:
    """Copy environment-derived identity info onto this instance.

    On the first successful call, the ``LIGHTNING_*`` environment variables are
    snapshotted into the class-level ``LLM._cached_auth_info`` and
    ``LLM._auth_info_cached`` is set. Every later call, from any instance, reuses
    that snapshot instead of re-reading the environment. The public-assistant
    mapping is likewise loaded once into ``LLM._public_assistants``.

    The snapshot is exposed as ``self._teamspace_name``, ``self._teamspace_id``,
    ``self._user_name``, ``self._user_id``, ``self._org_name`` and
    ``self._cloud_url``; ``self._org`` is reset to ``None``.

    Raises:
        ValueError: If ``LIGHTNING_TEAMSPACE`` is unset when no snapshot exists yet.
    """
    if not LLM._auth_info_cached:
        env = os.environ
        teamspace_name = env.get("LIGHTNING_TEAMSPACE", None)
        if teamspace_name is None:
            # NOTE(review): the message mentions "an argument", but this method never
            # reads a teamspace argument -- confirm the intended wording/behaviour.
            raise ValueError(
                "Teamspace name must be provided either through "
                "the environment variable LIGHTNING_TEAMSPACE or as an argument."
            )
        LLM._cached_auth_info = {
            "teamspace_name": teamspace_name,
            "teamspace_id": env.get("LIGHTNING_CLOUD_PROJECT_ID", None),
            "user_name": env.get("LIGHTNING_USERNAME", ""),
            "user_id": env.get("LIGHTNING_USER_ID", None),
            "org_name": env.get("LIGHTNING_ORG", ""),
            "cloud_url": env.get("LIGHTNING_CLOUD_URL", None),
        }
        LLM._auth_info_cached = True
    if LLM._public_assistants is None:
        LLM._public_assistants = _load_public_assistants()

    # Each snapshot key becomes a private ``_<key>`` attribute on this instance.
    snapshot = LLM._cached_auth_info
    for key in ("teamspace_name", "teamspace_id", "user_name", "user_id", "org_name", "cloud_url"):
        setattr(self, f"_{key}", snapshot[key])
    self._org = None
113
106
 
114
107
  def _parse_model_name(self, name: str) -> Tuple[str, str]:
115
108
  parts = name.split("/")
@@ -117,50 +110,60 @@ class LLM:
117
110
  # a user model or a org model
118
111
  return None, parts[0]
119
112
  if len(parts) == 2:
120
- return parts[0], parts[1]
113
+ return parts[0].lower(), parts[1]
121
114
  raise ValueError(
122
115
  f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
123
116
  )
124
117
 
125
- def _build_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
126
- result = {}
127
- for endpoint in endpoints:
128
- result.setdefault(endpoint.model, []).append(endpoint)
129
- return result
130
-
131
- def _get_public_models(self) -> List[str]:
132
- return self._llm_api.get_public_models()
133
-
134
- def _get_org_models(self) -> List[str]:
135
- return self._llm_api.get_org_models(self._org.id) if self._org else []
136
-
137
- def _get_user_models(self) -> List[str]:
138
- return self._llm_api.get_user_models(self._user.id) if self._user else []
139
-
140
- def _get_model(self) -> V1Assistant:
141
- # TODO how to handle multiple models with same model type? For now, just use the first one
142
- if self._model_name in self._public_models:
143
- return self._public_models.get(self._model_name)[0]
144
- if self._model_name in self._org_models:
145
- return self._org_models.get(self._model_name)[0]
146
- if self._model_name in self._user_models:
147
- return self._user_models.get(self._model_name)[0]
148
-
149
- available_models = []
150
- if self._public_models:
151
- available_models.append(f"Public Models: {', '.join(self._public_models.keys())}")
152
-
153
- if self._org and self._org_models:
154
- available_models.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
118
+ # returns the assistant ID
119
def _get_model_id(self) -> str:
    """Resolve the assistant ID for ``self._model_provider`` / ``self._model_name``.

    Lookup order:

    1. If the provider is a key of ``PUBLIC_MODEL_PROVIDERS``, the bundled
       ``LLM._public_assistants`` mapping is consulted first. This only happens
       on the production cloud (``self._cloud_url`` equal to
       ``https://lightning.ai``, trailing slash tolerated). Otherwise, or on a
       miss, the assistant is fetched through ``self._llm_api.get_assistant``.
       Public providers never fall back to org/user lookups.
    2. Otherwise, an organization model whose org name is the provider.
    3. Otherwise, a user model whose user name is the provider.

    For a bare model name the provider is ``None``, and ``None`` is passed as
    the org/user name. NOTE(review): confirm ``LLMApi.get_assistant`` resolves
    that against the caller's own account.

    Returns:
        The assistant ID.

    Raises:
        ValueError: If a public model cannot be fetched, or if both the
            organization and the user lookup fail. The message includes both
            lookup errors.
    """
    provider = self._model_provider
    model_name = self._model_name

    if provider in PUBLIC_MODEL_PROVIDERS:
        public_key = f"{provider}/{model_name}"
        # Bundled IDs are production IDs; tolerate a trailing "/" in the configured URL.
        on_prod = (self._cloud_url or "").rstrip("/") == "https://lightning.ai"
        bundled = LLM._public_assistants
        if on_prod and bundled and public_key in bundled:
            return bundled[public_key]
        try:
            return self._llm_api.get_assistant(
                model_provider=PUBLIC_MODEL_PROVIDERS[provider],
                model_name=model_name,
                user_name="",
                org_name="",
            )
        except Exception as e:
            raise ValueError(
                f"Public model '{public_key}' not found. "
                "Please check the model name or provider."
            ) from e

    # Try an organization model; remember why it failed instead of discarding the error.
    try:
        return self._llm_api.get_assistant(
            model_provider="",
            model_name=model_name,
            user_name="",
            org_name=provider,
        )
    except Exception as e:
        org_error = e

    # Try a user model.
    try:
        return self._llm_api.get_assistant(
            model_provider="",
            model_name=model_name,
            user_name=provider,
            org_name="",
        )
    except Exception as user_error:
        # Bare names have provider None; don't render them as "None/<model>".
        display_name = f"{provider}/{model_name}" if provider else model_name
        raise ValueError(
            f"Model '{display_name}' not found as either an org or user model.\n"
            f"Org lookup failed with: {org_error}\n"
            f"User lookup failed with: {user_error}"
        ) from user_error
161
164
 
162
165
  def _get_conversations(self) -> None:
163
- conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
166
+ conversations = self._llm_api.list_conversations(assistant_id=self._model_id)
164
167
  for conversation in conversations:
165
168
  if conversation.name and conversation.name not in self._conversations:
166
169
  self._conversations[conversation.name] = conversation.id
@@ -191,7 +194,7 @@ class LLM:
191
194
  conversation: Optional[str] = None,
192
195
  metadata: Optional[Dict[str, str]] = None,
193
196
  stream: bool = False,
194
- upload_local_images: bool = False,
197
+ **kwargs: Any,
195
198
  ) -> Union[str, AsyncGenerator[str, None]]:
196
199
  conversation_id = self._conversations.get(conversation) if conversation else None
197
200
  output = await self._llm_api.async_start_conversation(
@@ -199,12 +202,13 @@ class LLM:
199
202
  system_prompt=system_prompt,
200
203
  max_completion_tokens=max_completion_tokens,
201
204
  images=images,
202
- assistant_id=self._model.id,
205
+ assistant_id=self._model_id,
203
206
  conversation_id=conversation_id,
204
- billing_project_id=self._teamspace.id,
207
+ billing_project_id=self._teamspace_id,
205
208
  metadata=metadata,
206
209
  name=conversation,
207
210
  stream=stream,
211
+ **kwargs,
208
212
  )
209
213
  if not stream:
210
214
  if conversation and not conversation_id:
@@ -221,7 +225,7 @@ class LLM:
221
225
  conversation: Optional[str] = None,
222
226
  metadata: Optional[Dict[str, str]] = None,
223
227
  stream: bool = False,
224
- upload_local_images: bool = False,
228
+ **kwargs: Any,
225
229
  ) -> Union[str, Generator[str, None, None]]:
226
230
  if conversation and conversation not in self._conversations:
227
231
  self._get_conversations()
@@ -232,8 +236,6 @@ class LLM:
232
236
  for image in images:
233
237
  if not isinstance(image, str):
234
238
  raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
235
- if not image.startswith("http") and upload_local_images:
236
- self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
237
239
 
238
240
  conversation_id = self._conversations.get(conversation) if conversation else None
239
241
 
@@ -246,7 +248,7 @@ class LLM:
246
248
  conversation,
247
249
  metadata,
248
250
  stream,
249
- upload_local_images,
251
+ **kwargs,
250
252
  )
251
253
 
252
254
  output = self._llm_api.start_conversation(
@@ -254,12 +256,13 @@ class LLM:
254
256
  system_prompt=system_prompt,
255
257
  max_completion_tokens=max_completion_tokens,
256
258
  images=images,
257
- assistant_id=self._model.id,
259
+ assistant_id=self._model_id,
258
260
  conversation_id=conversation_id,
259
- billing_project_id=self._teamspace.id,
261
+ billing_project_id=self._teamspace_id,
260
262
  metadata=metadata,
261
263
  name=conversation,
262
264
  stream=stream,
265
+ **kwargs,
263
266
  )
264
267
  if not stream:
265
268
  if conversation and not conversation_id:
@@ -272,7 +275,7 @@ class LLM:
272
275
  return list(self._conversations.keys())
273
276
 
274
277
  def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
275
- return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
278
+ return self._llm_api.get_conversation(assistant_id=self._model_id, conversation_id=conversation_id)
276
279
 
277
280
  def get_history(self, conversation: str) -> Optional[List[Dict]]:
278
281
  if conversation not in self._conversations:
@@ -297,7 +300,7 @@ class LLM:
297
300
  self._get_conversations()
298
301
  if conversation in self._conversations:
299
302
  self._llm_api.reset_conversation(
300
- assistant_id=self._model.id,
303
+ assistant_id=self._model_id,
301
304
  conversation_id=self._conversations[conversation],
302
305
  )
303
306
  del self._conversations[conversation]
@@ -0,0 +1,8 @@
1
+ {
2
+ "openai/gpt-4o": "ast_01jdjds71fs8gt47jexzed4czs",
3
+ "openai/gpt-4": "ast_01jd38ze6tjbrcd4942nhz41zn",
4
+ "openai/o3-mini": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
5
+ "anthropic/claude-3-5-sonnet-20240620": "ast_01jd3923a6p98rqwh3dpj686pq",
6
+ "google/gemini-2.5-pro": "ast_01jz3tdb1fhey798k95pv61v57",
7
+ "google/gemini-2.5-flash": "ast_01jz3thxskg4fcdk4xhkjkym5a"
8
+ }
@@ -86,6 +86,12 @@ class Pipeline:
86
86
  if len(steps) == 0:
87
87
  raise ValueError("The provided steps is empty")
88
88
 
89
+ provided_cloud_account = None
90
+ if self._cloud_account:
91
+ provided_cloud_account = self._cloud_account
92
+ elif self._default_cluster:
93
+ provided_cloud_account = self._default_cluster.cluster_id
94
+
89
95
  for step_idx, pipeline_step in enumerate(steps):
90
96
  if pipeline_step.name in [None, ""]:
91
97
  pipeline_step.name = f"step-{step_idx}"
@@ -98,20 +104,29 @@ class Pipeline:
98
104
  pipeline_step.cloud_account = self._studio.cloud_account
99
105
  pipeline_step.studio = self._studio
100
106
 
107
+ if not pipeline_step.cloud_account and isinstance(provided_cloud_account, str):
108
+ pipeline_step.cloud_account = provided_cloud_account
109
+
101
110
  cluster_ids = set(step.cloud_account for step in steps if step.cloud_account not in ["", None]) # noqa: C401
102
111
 
103
- cloud_account = list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else "" # noqa: RUF015
112
+ cloud_account = (
113
+ list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else "" # noqa: RUF015
114
+ )
104
115
 
105
116
  steps = [step.to_proto(self._teamspace, cloud_account, self._shared_filesystem) for step in steps]
106
117
 
107
118
  proto_steps = prepare_steps(steps)
108
119
  schedules = schedules or []
109
120
 
121
+ for schedule_idx, schedule in enumerate(schedules):
122
+ if schedule.name is None:
123
+ schedule.name = f"schedule-{schedule_idx}"
124
+
110
125
  parent_pipeline_id = None if self._pipeline is None else self._pipeline.id
111
126
 
112
127
  self._pipeline = self._pipeline_api.create_pipeline(
113
128
  self._name,
114
- self._teamspace.id,
129
+ self._teamspace,
115
130
  proto_steps,
116
131
  self._shared_filesystem,
117
132
  schedules,
@@ -1,7 +1,8 @@
1
1
  import os
2
2
  from typing import Any, ClassVar, Dict, List
3
3
 
4
- from lightning_sdk.lightning_cloud.openapi.models import V1JobSpec, V1Pipeline, V1PipelineStepType
4
+ from lightning_sdk.lightning_cloud.openapi.models import V1Pipeline, V1PipelineStepType
5
+ from lightning_sdk.pipeline.utils import _get_spec
5
6
 
6
7
 
7
8
  class PipelinePrinter:
@@ -31,7 +32,7 @@ class PipelinePrinter:
31
32
  self._schedules = schedules
32
33
  cluster_ids: set[str] = set()
33
34
  for step in self._proto_steps:
34
- job_spec = self._get_spec(step)
35
+ job_spec = _get_spec(step)
35
36
  cluster_ids.add(job_spec.cluster_id)
36
37
  self._cluster_ids = cluster_ids
37
38
 
@@ -107,15 +108,15 @@ class PipelinePrinter:
107
108
 
108
109
  if self._shared_filesystem.enabled and len(self._cluster_ids) == 1:
109
110
  shared_path = ""
111
+ cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
110
112
  if self._pipeline.shared_filesystem.s3_folder:
111
- cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
112
113
  shared_path = f"/teamspace/s3_folders/pipelines-{cluster_id}"
114
+ if self._pipeline.shared_filesystem.gcs_folder:
115
+ shared_path = f"/teamspace/gcs_folders/pipelines-{cluster_id}"
116
+ if self._pipeline.shared_filesystem.efs:
117
+ shared_path = f"/teamspace/efs_connections/pipelines-{cluster_id}"
118
+ if self._pipeline.shared_filesystem.filestore:
119
+ shared_path = f"/teamspace/gcs_connections/pipelines-{cluster_id}"
120
+
113
121
  if shared_path:
114
122
  self._print(f" - {shared_path}")
115
-
116
- def _get_spec(self, step: Any) -> V1JobSpec:
117
- if step.type == V1PipelineStepType.DEPLOYMENT:
118
- return step.deployment.spec
119
- if step.type == V1PipelineStepType.MMT:
120
- return step.mmt.spec
121
- return step.job.spec
@@ -69,12 +69,15 @@ class DeploymentStep:
69
69
 
70
70
  self.machine = machine or Machine.CPU
71
71
  self.image = image
72
+ autoscaling_metric_name = (
73
+ ("CPU" if self.machine.is_cpu() else "GPU") if isinstance(self.machine, Machine) else "CPU"
74
+ )
72
75
  self.autoscale = autoscale or AutoScaleConfig(
73
76
  min_replicas=0,
74
77
  max_replicas=1,
75
78
  target_metrics=[
76
79
  AutoScalingMetric(
77
- name="CPU" if self.machine.is_cpu() else "GPU",
80
+ name=autoscaling_metric_name,
78
81
  target=80,
79
82
  )
80
83
  ],
@@ -1,6 +1,11 @@
1
- from typing import List, Literal, Optional, Union
2
-
3
- from lightning_sdk.lightning_cloud.openapi.models import V1PipelineStep, V1SharedFilesystem
1
+ from typing import Any, List, Literal, Optional, Union
2
+
3
+ from lightning_sdk.lightning_cloud.openapi.models import (
4
+ V1JobSpec,
5
+ V1PipelineStep,
6
+ V1PipelineStepType,
7
+ V1SharedFilesystem,
8
+ )
4
9
  from lightning_sdk.studio import Studio
5
10
 
6
11
  DEFAULT = "DEFAULT"
@@ -21,7 +26,7 @@ def prepare_steps(steps: List["V1PipelineStep"]) -> List["V1PipelineStep"]:
21
26
  else:
22
27
  raise ValueError(f"A step with the name {current_step.name} already exists.")
23
28
 
24
- if steps[0].wait_for != DEFAULT:
29
+ if steps[0].wait_for not in [None, DEFAULT, []]:
25
30
  raise ValueError("The first step isn't allowed to receive `wait_for=...`.")
26
31
 
27
32
  steps[0].wait_for = []
@@ -89,3 +94,23 @@ def _to_wait_for(wait_for: Optional[Union[str, List[str]]]) -> Optional[Union[Li
89
94
  return []
90
95
 
91
96
  return wait_for if isinstance(wait_for, list) else [wait_for]
97
+
98
+
99
def _get_cloud_account(steps: List[V1PipelineStep]) -> Optional[str]:
    """Pick a deterministic cloud account (cluster ID) from the steps' job specs.

    Each step's spec is obtained via ``_get_spec``. Unset cluster IDs (``None``
    or ``""``) are ignored, matching how the pipeline filters unset
    ``cloud_account`` values. When several distinct IDs remain, the
    lexicographically smallest one is returned so the choice is stable.

    Args:
        steps: Pipeline steps to inspect.

    Returns:
        The chosen cluster ID, or ``None`` if ``steps`` is empty or no step has
        a cluster ID set.
    """
    cluster_ids = {_get_spec(step).cluster_id for step in steps}
    # Without this, "" would always win the ordering, and a None mixed with
    # strings would make the comparison raise TypeError.
    cluster_ids.difference_update({None, ""})
    return min(cluster_ids) if cluster_ids else None
109
+
110
+
111
def _get_spec(step: Any) -> V1JobSpec:
    """Return the job spec of a pipeline step, whatever its type.

    Deployment steps keep their spec on ``step.deployment`` and MMT steps on
    ``step.mmt``. Any other step type is read from ``step.job``.
    """
    kind = step.type
    if kind == V1PipelineStepType.DEPLOYMENT:
        holder = step.deployment
    elif kind == V1PipelineStepType.MMT:
        holder = step.mmt
    else:
        holder = step.job
    return holder.spec
lightning_sdk/studio.py CHANGED
@@ -88,6 +88,9 @@ class Studio:
88
88
 
89
89
  self._plugins = {}
90
90
 
91
+ if self._teamspace is None:
92
+ raise ValueError("Couldn't resolve teamspace from the provided name, org, or user")
93
+
91
94
  if provider is not None:
92
95
  if isinstance(provider, str) and provider in Provider.__members__:
93
96
  provider = Provider(provider)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lightning_sdk
3
- Version: 0.2.23
3
+ Version: 0.2.24rc1
4
4
  Summary: SDK to develop using Lightning AI Studios
5
5
  Author-email: Lightning-AI <justus@lightning.ai>
6
6
  License: MIT License