lightning-sdk 0.2.23__py3-none-any.whl → 0.2.24rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/api/base_studio_api.py +9 -2
- lightning_sdk/api/deployment_api.py +9 -9
- lightning_sdk/api/llm_api.py +23 -13
- lightning_sdk/api/pipeline_api.py +31 -11
- lightning_sdk/api/studio_api.py +4 -0
- lightning_sdk/base_studio.py +22 -6
- lightning_sdk/deployment/deployment.py +17 -7
- lightning_sdk/lightning_cloud/openapi/__init__.py +18 -0
- lightning_sdk/lightning_cloud/openapi/api/__init__.py +2 -0
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +114 -1
- lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +129 -0
- lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +9 -1
- lightning_sdk/lightning_cloud/openapi/api/lit_logger_service_api.py +13 -1
- lightning_sdk/lightning_cloud/openapi/api/organizations_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/api/pipelines_service_api.py +4 -4
- lightning_sdk/lightning_cloud/openapi/api/user_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/api/volume_service_api.py +258 -0
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +16 -0
- lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/credits_autoreplenish_body1.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/externalv1_user_status.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body1.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/project_id_agents_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_schedules_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/schedules_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/update.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_billing_tier.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_specialized_view.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_expert.py +279 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_capacity_reservation.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_status.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_create_organization_request.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_data_connection_tier.py +103 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_status.py +47 -21
- lightning_sdk/lightning_cloud/openapi/models/v1_external_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_filestore_data_connection.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_volume_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_instance_overprovisioning_spec.py +1 -27
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1_status.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_cloudy_experts_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_schedule.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_schedule_action_type.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_schedule_resource_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_token_usage.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_organization_credits_auto_replenish_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_user_credits_auto_replenish_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_volume_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +317 -31
- lightning_sdk/lightning_cloud/openapi/models/volumes_id_body.py +123 -0
- lightning_sdk/llm/llm.py +118 -115
- lightning_sdk/llm/public_assistants.json +8 -0
- lightning_sdk/pipeline/pipeline.py +17 -2
- lightning_sdk/pipeline/printer.py +11 -10
- lightning_sdk/pipeline/steps.py +4 -1
- lightning_sdk/pipeline/utils.py +29 -4
- lightning_sdk/studio.py +3 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/METADATA +1 -1
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/RECORD +83 -64
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.2.23.dist-info → lightning_sdk-0.2.24rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
external/v1/auth_service.proto
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) # noqa: E501
|
|
7
|
+
|
|
8
|
+
OpenAPI spec version: version not set
|
|
9
|
+
|
|
10
|
+
Generated by: https://github.com/swagger-api/swagger-codegen.git
|
|
11
|
+
|
|
12
|
+
NOTE
|
|
13
|
+
----
|
|
14
|
+
standard swagger-codegen-cli for this python client has been modified
|
|
15
|
+
by custom templates. The purpose of these templates is to include
|
|
16
|
+
typing information in the API and Model code. Please refer to the
|
|
17
|
+
main grid repository for more info
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import pprint
|
|
21
|
+
import re # noqa: F401
|
|
22
|
+
|
|
23
|
+
from typing import TYPE_CHECKING
|
|
24
|
+
|
|
25
|
+
import six
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from datetime import datetime
|
|
29
|
+
from lightning_sdk.lightning_cloud.openapi.models import *
|
|
30
|
+
|
|
31
|
+
class VolumesIdBody(object):
    """Swagger model holding a single optional ``volume`` attribute.

    NOTE: This class is auto generated by the swagger code generator program.
    Do not edit the class manually.

    Attributes:
      swagger_types (dict): The key is attribute name
                            and the value is attribute type.
      attribute_map (dict): The key is attribute name
                            and the value is json key in definition.
    """

    swagger_types = {
        'volume': 'V1Volume'
    }

    attribute_map = {
        'volume': 'volume'
    }

    def __init__(self, volume: 'V1Volume' = None):  # noqa: E501
        """VolumesIdBody - a model defined in Swagger"""  # noqa: E501
        self._volume = None
        self.discriminator = None
        # Route through the property setter only when a value was supplied.
        if volume is not None:
            self.volume = volume

    @property
    def volume(self) -> 'V1Volume':
        """Gets the volume of this VolumesIdBody.  # noqa: E501

        :return: The volume of this VolumesIdBody.  # noqa: E501
        :rtype: V1Volume
        """
        return self._volume

    @volume.setter
    def volume(self, volume: 'V1Volume'):
        """Sets the volume of this VolumesIdBody.

        :param volume: The volume of this VolumesIdBody.  # noqa: E501
        :type: V1Volume
        """
        self._volume = volume

    def to_dict(self) -> dict:
        """Returns the model properties as a dict"""

        def _plain(item):
            # Anything exposing ``to_dict`` serializes itself; the rest passes through.
            return item.to_dict() if hasattr(item, "to_dict") else item

        result = {}
        for attr in self.swagger_types:
            value = getattr(self, attr)
            # Branch order matters: lists first, then nested models, then dicts.
            if isinstance(value, list):
                result[attr] = [_plain(entry) for entry in value]
            elif hasattr(value, "to_dict"):
                result[attr] = value.to_dict()
            elif isinstance(value, dict):
                result[attr] = {key: _plain(entry) for key, entry in value.items()}
            else:
                result[attr] = value

        # Generator boilerplate: only relevant if the model itself were a dict.
        if issubclass(VolumesIdBody, dict):
            for key, value in self.items():
                result[key] = value

        return result

    def to_str(self) -> str:
        """Returns the string representation of the model"""
        as_dict = self.to_dict()
        return pprint.pformat(as_dict)

    def __repr__(self) -> str:
        """For `print` and `pprint`"""
        return self.to_str()

    def __eq__(self, other: 'VolumesIdBody') -> bool:
        """Returns true if both objects are equal"""
        return isinstance(other, VolumesIdBody) and self.__dict__ == other.__dict__

    def __ne__(self, other: 'VolumesIdBody') -> bool:
        """Returns true if both objects are not equal"""
        return not (self == other)
|
lightning_sdk/llm/llm.py
CHANGED
|
@@ -1,19 +1,37 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import os
|
|
2
|
-
import
|
|
3
|
-
from typing import AsyncGenerator, Dict, Generator, List, Optional, Set, Tuple, Union
|
|
3
|
+
from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
from lightning_sdk.api.llm_api import LLMApi
|
|
6
|
-
from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
|
|
7
|
-
from lightning_sdk.lightning_cloud.openapi import V1Assistant
|
|
8
6
|
from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
7
|
+
|
|
8
|
+
PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
|
|
9
|
+
"openai": "OpenAI",
|
|
10
|
+
"anthropic": "Anthropic",
|
|
11
|
+
"google": "Google",
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _load_public_assistants() -> Dict[str, str]:
    """Load public assistants from a JSON file.

    Reads ``public_assistants.json`` located next to this module.

    Returns:
        The parsed JSON object, or an empty dict when the file cannot be
        read, cannot be parsed, or does not contain a JSON object. In the
        failure case a warning is printed instead of raising.
    """
    try:
        json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
        # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
        with open(json_path, encoding="utf-8") as f:
            assistants = json.load(f)
        # Callers use the result with ``in`` and ``[]`` on string keys, so only a
        # JSON object is usable; reject anything else through the same warning path.
        if not isinstance(assistants, dict):
            raise TypeError(f"expected a JSON object, got {type(assistants).__name__}")
        return assistants
    except Exception as e:
        # Deliberately best-effort: a failed load is reported and yields no assistants.
        print(f"[warning] Failed to load public_assistants.json: {e}")
        return {}
|
|
14
24
|
|
|
15
25
|
|
|
16
26
|
class LLM:
|
|
27
|
+
_auth_info_cached: ClassVar[bool] = False
|
|
28
|
+
_cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
|
|
29
|
+
_llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
|
|
30
|
+
_public_assistants: ClassVar[Optional[Dict[str, str]]] = None
|
|
31
|
+
|
|
32
|
+
def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
    """Allocate a bare instance.

    ``name``, ``teamspace`` and ``enable_async`` are accepted but not used
    here; allocation is delegated unchanged to the parent class.
    """
    instance = super().__new__(cls)
    return instance
|
|
34
|
+
|
|
17
35
|
def __init__(
|
|
18
36
|
self,
|
|
19
37
|
name: str,
|
|
@@ -36,67 +54,18 @@ class LLM:
|
|
|
36
54
|
Raises:
|
|
37
55
|
ValueError: If teamspace information cannot be resolved.
|
|
38
56
|
"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
possible_teamspaces = menu._get_possible_teamspaces(user)
|
|
42
|
-
if teamspace is None:
|
|
43
|
-
# get current teamspace
|
|
44
|
-
self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
|
|
45
|
-
else:
|
|
46
|
-
self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))
|
|
47
|
-
|
|
48
|
-
if self._teamspace is None:
|
|
49
|
-
# select the first available teamspace
|
|
50
|
-
first_teamspace = next(iter(possible_teamspaces.values()), None)
|
|
51
|
-
|
|
52
|
-
if first_teamspace:
|
|
53
|
-
self._teamspace = Teamspace(
|
|
54
|
-
name=first_teamspace["name"],
|
|
55
|
-
org=first_teamspace["org"],
|
|
56
|
-
user=first_teamspace["user"],
|
|
57
|
-
)
|
|
58
|
-
warnings.warn(
|
|
59
|
-
f"No teamspace given. Using teamspace: {self._teamspace.name}.",
|
|
60
|
-
UserWarning,
|
|
61
|
-
stacklevel=2,
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
if self._teamspace is None:
|
|
65
|
-
raise ValueError("Teamspace is required for billing but could not be resolved. ")
|
|
66
|
-
|
|
67
|
-
self._user = user
|
|
57
|
+
# TODO support user input teamspace
|
|
58
|
+
self._get_auth_info()
|
|
68
59
|
|
|
69
60
|
self._model_provider, self._model_name = self._parse_model_name(name)
|
|
70
|
-
|
|
71
|
-
self._llm_api = LLMApi()
|
|
72
61
|
self._enable_async = enable_async
|
|
73
62
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
63
|
+
# Reuse LLMApi per teamspace (as billing is based on teamspace)
|
|
64
|
+
if teamspace not in LLM._llm_api_cache:
|
|
65
|
+
LLM._llm_api_cache[teamspace] = LLMApi()
|
|
66
|
+
self._llm_api = LLM._llm_api_cache[teamspace]
|
|
77
67
|
|
|
78
|
-
|
|
79
|
-
# check if user has access to the org
|
|
80
|
-
self._org_models = self._build_model_lookup(self._get_org_models())
|
|
81
|
-
except ApiException:
|
|
82
|
-
warnings.warn(
|
|
83
|
-
f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
|
|
84
|
-
" Proceeding with appropriate org models, user models, or public models.",
|
|
85
|
-
UserWarning,
|
|
86
|
-
stacklevel=2,
|
|
87
|
-
)
|
|
88
|
-
self._model_provider = None
|
|
89
|
-
raise
|
|
90
|
-
except ApiException:
|
|
91
|
-
if isinstance(self._teamspace.owner, Organization):
|
|
92
|
-
self._org = self._teamspace.owner
|
|
93
|
-
else:
|
|
94
|
-
self._org = None
|
|
95
|
-
self._org_models = self._build_model_lookup(self._get_org_models())
|
|
96
|
-
|
|
97
|
-
self._public_models = self._build_model_lookup(self._get_public_models())
|
|
98
|
-
self._user_models = self._build_model_lookup(self._get_user_models())
|
|
99
|
-
self._model = self._get_model()
|
|
68
|
+
self._model_id = self._get_model_id()
|
|
100
69
|
self._conversations = {}
|
|
101
70
|
|
|
102
71
|
@property
|
|
@@ -107,9 +76,33 @@ class LLM:
|
|
|
107
76
|
def provider(self) -> str:
|
|
108
77
|
return self._model_provider
|
|
109
78
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
79
|
+
def _get_auth_info(self) -> None:
    """Copy the class-wide cached auth details onto this instance.

    On the first call the class-level cache is filled from the environment
    variables ``LIGHTNING_TEAMSPACE``, ``LIGHTNING_CLOUD_PROJECT_ID``,
    ``LIGHTNING_USERNAME``, ``LIGHTNING_USER_ID``, ``LIGHTNING_ORG`` and
    ``LIGHTNING_CLOUD_URL``, and the public-assistants table is loaded if it
    has not been loaded yet. Later calls reuse the cached values.

    Raises:
        ValueError: If ``LIGHTNING_TEAMSPACE`` is unset when the cache is built.
    """
    if not LLM._auth_info_cached:
        env = os.environ
        teamspace_name = env.get("LIGHTNING_TEAMSPACE", None)
        if teamspace_name is None:
            raise ValueError(
                "Teamspace name must be provided either through "
                "the environment variable LIGHTNING_TEAMSPACE or as an argument."
            )

        auth_info = {}
        auth_info["teamspace_name"] = teamspace_name
        auth_info["teamspace_id"] = env.get("LIGHTNING_CLOUD_PROJECT_ID", None)
        auth_info["user_name"] = env.get("LIGHTNING_USERNAME", "")
        auth_info["user_id"] = env.get("LIGHTNING_USER_ID", None)
        auth_info["org_name"] = env.get("LIGHTNING_ORG", "")
        auth_info["cloud_url"] = env.get("LIGHTNING_CLOUD_URL", None)
        LLM._cached_auth_info = auth_info
        LLM._auth_info_cached = True

        if LLM._public_assistants is None:
            LLM._public_assistants = _load_public_assistants()

    # Always assign to the current instance
    cached = LLM._cached_auth_info
    self._teamspace_name = cached["teamspace_name"]
    self._teamspace_id = cached["teamspace_id"]
    self._user_name = cached["user_name"]
    self._user_id = cached["user_id"]
    self._org_name = cached["org_name"]
    self._cloud_url = cached["cloud_url"]
    self._org = None
|
|
113
106
|
|
|
114
107
|
def _parse_model_name(self, name: str) -> Tuple[str, str]:
|
|
115
108
|
parts = name.split("/")
|
|
@@ -117,50 +110,60 @@ class LLM:
|
|
|
117
110
|
# a user model or a org model
|
|
118
111
|
return None, parts[0]
|
|
119
112
|
if len(parts) == 2:
|
|
120
|
-
return parts[0], parts[1]
|
|
113
|
+
return parts[0].lower(), parts[1]
|
|
121
114
|
raise ValueError(
|
|
122
115
|
f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
|
|
123
116
|
)
|
|
124
117
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
return self._user_models.get(self._model_name)[0]
|
|
148
|
-
|
|
149
|
-
available_models = []
|
|
150
|
-
if self._public_models:
|
|
151
|
-
available_models.append(f"Public Models: {', '.join(self._public_models.keys())}")
|
|
152
|
-
|
|
153
|
-
if self._org and self._org_models:
|
|
154
|
-
available_models.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
|
|
118
|
+
# returns the assistant ID
|
|
119
|
+
def _get_model_id(self) -> str:
    """Resolve the assistant ID for this instance's provider/model pair.

    Lookup order:
      1. If the provider is a key of ``PUBLIC_MODEL_PROVIDERS``: use the
         bundled public-assistants table when the cloud URL is
         ``https://lightning.ai`` and the table has an entry; otherwise ask
         the API for the public assistant.
      2. Otherwise ask the API treating the provider as an organization name.
      3. If that fails, ask the API treating the provider as a user name.

    Raises:
        ValueError: If the public lookup fails, or if neither the
            organization nor the user lookup succeeds.
    """
    provider = self._model_provider
    model_name = self._model_name

    if provider in PUBLIC_MODEL_PROVIDERS:
        qualified_name = f"{provider}/{model_name}"
        # if prod
        if (
            self._cloud_url == "https://lightning.ai"
            and LLM._public_assistants
            and qualified_name in LLM._public_assistants
        ):
            return LLM._public_assistants[qualified_name]

        try:
            return self._llm_api.get_assistant(
                model_provider=PUBLIC_MODEL_PROVIDERS[provider],
                model_name=model_name,
                user_name="",
                org_name="",
            )
        except Exception as e:
            raise ValueError(
                f"Public model '{qualified_name}' not found. "
                "Please check the model name or provider."
            ) from e

    api = self._llm_api

    # Try organization model
    try:
        return api.get_assistant(
            model_provider="",
            model_name=model_name,
            user_name="",
            org_name=provider,
        )
    except Exception:
        # Not an org model; fall through to the user lookup below.
        pass

    # Try user model
    try:
        return api.get_assistant(
            model_provider="",
            model_name=model_name,
            user_name=provider,
            org_name="",
        )
    except Exception as user_error:
        raise ValueError(
            f"Model '{provider}/{model_name}' not found as either an org or user model.\n"
        ) from user_error
|
|
161
164
|
|
|
162
165
|
def _get_conversations(self) -> None:
|
|
163
|
-
conversations = self._llm_api.list_conversations(assistant_id=self.
|
|
166
|
+
conversations = self._llm_api.list_conversations(assistant_id=self._model_id)
|
|
164
167
|
for conversation in conversations:
|
|
165
168
|
if conversation.name and conversation.name not in self._conversations:
|
|
166
169
|
self._conversations[conversation.name] = conversation.id
|
|
@@ -191,7 +194,7 @@ class LLM:
|
|
|
191
194
|
conversation: Optional[str] = None,
|
|
192
195
|
metadata: Optional[Dict[str, str]] = None,
|
|
193
196
|
stream: bool = False,
|
|
194
|
-
|
|
197
|
+
**kwargs: Any,
|
|
195
198
|
) -> Union[str, AsyncGenerator[str, None]]:
|
|
196
199
|
conversation_id = self._conversations.get(conversation) if conversation else None
|
|
197
200
|
output = await self._llm_api.async_start_conversation(
|
|
@@ -199,12 +202,13 @@ class LLM:
|
|
|
199
202
|
system_prompt=system_prompt,
|
|
200
203
|
max_completion_tokens=max_completion_tokens,
|
|
201
204
|
images=images,
|
|
202
|
-
assistant_id=self.
|
|
205
|
+
assistant_id=self._model_id,
|
|
203
206
|
conversation_id=conversation_id,
|
|
204
|
-
billing_project_id=self.
|
|
207
|
+
billing_project_id=self._teamspace_id,
|
|
205
208
|
metadata=metadata,
|
|
206
209
|
name=conversation,
|
|
207
210
|
stream=stream,
|
|
211
|
+
**kwargs,
|
|
208
212
|
)
|
|
209
213
|
if not stream:
|
|
210
214
|
if conversation and not conversation_id:
|
|
@@ -221,7 +225,7 @@ class LLM:
|
|
|
221
225
|
conversation: Optional[str] = None,
|
|
222
226
|
metadata: Optional[Dict[str, str]] = None,
|
|
223
227
|
stream: bool = False,
|
|
224
|
-
|
|
228
|
+
**kwargs: Any,
|
|
225
229
|
) -> Union[str, Generator[str, None, None]]:
|
|
226
230
|
if conversation and conversation not in self._conversations:
|
|
227
231
|
self._get_conversations()
|
|
@@ -232,8 +236,6 @@ class LLM:
|
|
|
232
236
|
for image in images:
|
|
233
237
|
if not isinstance(image, str):
|
|
234
238
|
raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
|
|
235
|
-
if not image.startswith("http") and upload_local_images:
|
|
236
|
-
self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
|
|
237
239
|
|
|
238
240
|
conversation_id = self._conversations.get(conversation) if conversation else None
|
|
239
241
|
|
|
@@ -246,7 +248,7 @@ class LLM:
|
|
|
246
248
|
conversation,
|
|
247
249
|
metadata,
|
|
248
250
|
stream,
|
|
249
|
-
|
|
251
|
+
**kwargs,
|
|
250
252
|
)
|
|
251
253
|
|
|
252
254
|
output = self._llm_api.start_conversation(
|
|
@@ -254,12 +256,13 @@ class LLM:
|
|
|
254
256
|
system_prompt=system_prompt,
|
|
255
257
|
max_completion_tokens=max_completion_tokens,
|
|
256
258
|
images=images,
|
|
257
|
-
assistant_id=self.
|
|
259
|
+
assistant_id=self._model_id,
|
|
258
260
|
conversation_id=conversation_id,
|
|
259
|
-
billing_project_id=self.
|
|
261
|
+
billing_project_id=self._teamspace_id,
|
|
260
262
|
metadata=metadata,
|
|
261
263
|
name=conversation,
|
|
262
264
|
stream=stream,
|
|
265
|
+
**kwargs,
|
|
263
266
|
)
|
|
264
267
|
if not stream:
|
|
265
268
|
if conversation and not conversation_id:
|
|
@@ -272,7 +275,7 @@ class LLM:
|
|
|
272
275
|
return list(self._conversations.keys())
|
|
273
276
|
|
|
274
277
|
def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
|
|
275
|
-
return self._llm_api.get_conversation(assistant_id=self.
|
|
278
|
+
return self._llm_api.get_conversation(assistant_id=self._model_id, conversation_id=conversation_id)
|
|
276
279
|
|
|
277
280
|
def get_history(self, conversation: str) -> Optional[List[Dict]]:
|
|
278
281
|
if conversation not in self._conversations:
|
|
@@ -297,7 +300,7 @@ class LLM:
|
|
|
297
300
|
self._get_conversations()
|
|
298
301
|
if conversation in self._conversations:
|
|
299
302
|
self._llm_api.reset_conversation(
|
|
300
|
-
assistant_id=self.
|
|
303
|
+
assistant_id=self._model_id,
|
|
301
304
|
conversation_id=self._conversations[conversation],
|
|
302
305
|
)
|
|
303
306
|
del self._conversations[conversation]
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
{
|
|
2
|
+
"openai/gpt-4o": "ast_01jdjds71fs8gt47jexzed4czs",
|
|
3
|
+
"openai/gpt-4": "ast_01jd38ze6tjbrcd4942nhz41zn",
|
|
4
|
+
"openai/o3-mini": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
|
|
5
|
+
"anthropic/claude-3-5-sonnet-20240620": "ast_01jd3923a6p98rqwh3dpj686pq",
|
|
6
|
+
"google/gemini-2.5-pro": "ast_01jz3tdb1fhey798k95pv61v57",
|
|
7
|
+
"google/gemini-2.5-flash": "ast_01jz3thxskg4fcdk4xhkjkym5a"
|
|
8
|
+
}
|
|
@@ -86,6 +86,12 @@ class Pipeline:
|
|
|
86
86
|
if len(steps) == 0:
|
|
87
87
|
raise ValueError("The provided steps is empty")
|
|
88
88
|
|
|
89
|
+
provided_cloud_account = None
|
|
90
|
+
if self._cloud_account:
|
|
91
|
+
provided_cloud_account = self._cloud_account
|
|
92
|
+
elif self._default_cluster:
|
|
93
|
+
provided_cloud_account = self._default_cluster.cluster_id
|
|
94
|
+
|
|
89
95
|
for step_idx, pipeline_step in enumerate(steps):
|
|
90
96
|
if pipeline_step.name in [None, ""]:
|
|
91
97
|
pipeline_step.name = f"step-{step_idx}"
|
|
@@ -98,20 +104,29 @@ class Pipeline:
|
|
|
98
104
|
pipeline_step.cloud_account = self._studio.cloud_account
|
|
99
105
|
pipeline_step.studio = self._studio
|
|
100
106
|
|
|
107
|
+
if not pipeline_step.cloud_account and isinstance(provided_cloud_account, str):
|
|
108
|
+
pipeline_step.cloud_account = provided_cloud_account
|
|
109
|
+
|
|
101
110
|
cluster_ids = set(step.cloud_account for step in steps if step.cloud_account not in ["", None]) # noqa: C401
|
|
102
111
|
|
|
103
|
-
cloud_account =
|
|
112
|
+
cloud_account = (
|
|
113
|
+
list(cluster_ids)[0] if len(cluster_ids) == 1 and self._cloud_account is None else "" # noqa: RUF015
|
|
114
|
+
)
|
|
104
115
|
|
|
105
116
|
steps = [step.to_proto(self._teamspace, cloud_account, self._shared_filesystem) for step in steps]
|
|
106
117
|
|
|
107
118
|
proto_steps = prepare_steps(steps)
|
|
108
119
|
schedules = schedules or []
|
|
109
120
|
|
|
121
|
+
for schedule_idx, schedule in enumerate(schedules):
|
|
122
|
+
if schedule.name is None:
|
|
123
|
+
schedule.name = f"schedule-{schedule_idx}"
|
|
124
|
+
|
|
110
125
|
parent_pipeline_id = None if self._pipeline is None else self._pipeline.id
|
|
111
126
|
|
|
112
127
|
self._pipeline = self._pipeline_api.create_pipeline(
|
|
113
128
|
self._name,
|
|
114
|
-
self._teamspace
|
|
129
|
+
self._teamspace,
|
|
115
130
|
proto_steps,
|
|
116
131
|
self._shared_filesystem,
|
|
117
132
|
schedules,
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Any, ClassVar, Dict, List
|
|
3
3
|
|
|
4
|
-
from lightning_sdk.lightning_cloud.openapi.models import
|
|
4
|
+
from lightning_sdk.lightning_cloud.openapi.models import V1Pipeline, V1PipelineStepType
|
|
5
|
+
from lightning_sdk.pipeline.utils import _get_spec
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class PipelinePrinter:
|
|
@@ -31,7 +32,7 @@ class PipelinePrinter:
|
|
|
31
32
|
self._schedules = schedules
|
|
32
33
|
cluster_ids: set[str] = set()
|
|
33
34
|
for step in self._proto_steps:
|
|
34
|
-
job_spec =
|
|
35
|
+
job_spec = _get_spec(step)
|
|
35
36
|
cluster_ids.add(job_spec.cluster_id)
|
|
36
37
|
self._cluster_ids = cluster_ids
|
|
37
38
|
|
|
@@ -107,15 +108,15 @@ class PipelinePrinter:
|
|
|
107
108
|
|
|
108
109
|
if self._shared_filesystem.enabled and len(self._cluster_ids) == 1:
|
|
109
110
|
shared_path = ""
|
|
111
|
+
cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
|
|
110
112
|
if self._pipeline.shared_filesystem.s3_folder:
|
|
111
|
-
cluster_id = list(self._cluster_ids)[0] # noqa: RUF015
|
|
112
113
|
shared_path = f"/teamspace/s3_folders/pipelines-{cluster_id}"
|
|
114
|
+
if self._pipeline.shared_filesystem.gcs_folder:
|
|
115
|
+
shared_path = f"/teamspace/gcs_folders/pipelines-{cluster_id}"
|
|
116
|
+
if self._pipeline.shared_filesystem.efs:
|
|
117
|
+
shared_path = f"/teamspace/efs_connections/pipelines-{cluster_id}"
|
|
118
|
+
if self._pipeline.shared_filesystem.filestore:
|
|
119
|
+
shared_path = f"/teamspace/gcs_connections/pipelines-{cluster_id}"
|
|
120
|
+
|
|
113
121
|
if shared_path:
|
|
114
122
|
self._print(f" - {shared_path}")
|
|
115
|
-
|
|
116
|
-
def _get_spec(self, step: Any) -> V1JobSpec:
|
|
117
|
-
if step.type == V1PipelineStepType.DEPLOYMENT:
|
|
118
|
-
return step.deployment.spec
|
|
119
|
-
if step.type == V1PipelineStepType.MMT:
|
|
120
|
-
return step.mmt.spec
|
|
121
|
-
return step.job.spec
|
lightning_sdk/pipeline/steps.py
CHANGED
|
@@ -69,12 +69,15 @@ class DeploymentStep:
|
|
|
69
69
|
|
|
70
70
|
self.machine = machine or Machine.CPU
|
|
71
71
|
self.image = image
|
|
72
|
+
autoscaling_metric_name = (
|
|
73
|
+
("CPU" if self.machine.is_cpu() else "GPU") if isinstance(self.machine, Machine) else "CPU"
|
|
74
|
+
)
|
|
72
75
|
self.autoscale = autoscale or AutoScaleConfig(
|
|
73
76
|
min_replicas=0,
|
|
74
77
|
max_replicas=1,
|
|
75
78
|
target_metrics=[
|
|
76
79
|
AutoScalingMetric(
|
|
77
|
-
name=
|
|
80
|
+
name=autoscaling_metric_name,
|
|
78
81
|
target=80,
|
|
79
82
|
)
|
|
80
83
|
],
|
lightning_sdk/pipeline/utils.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
|
1
|
-
from typing import List, Literal, Optional, Union
|
|
2
|
-
|
|
3
|
-
from lightning_sdk.lightning_cloud.openapi.models import
|
|
1
|
+
from typing import Any, List, Literal, Optional, Union
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.lightning_cloud.openapi.models import (
|
|
4
|
+
V1JobSpec,
|
|
5
|
+
V1PipelineStep,
|
|
6
|
+
V1PipelineStepType,
|
|
7
|
+
V1SharedFilesystem,
|
|
8
|
+
)
|
|
4
9
|
from lightning_sdk.studio import Studio
|
|
5
10
|
|
|
6
11
|
DEFAULT = "DEFAULT"
|
|
@@ -21,7 +26,7 @@ def prepare_steps(steps: List["V1PipelineStep"]) -> List["V1PipelineStep"]:
|
|
|
21
26
|
else:
|
|
22
27
|
raise ValueError(f"A step with the name {current_step.name} already exists.")
|
|
23
28
|
|
|
24
|
-
if steps[0].wait_for
|
|
29
|
+
if steps[0].wait_for not in [None, DEFAULT, []]:
|
|
25
30
|
raise ValueError("The first step isn't allowed to receive `wait_for=...`.")
|
|
26
31
|
|
|
27
32
|
steps[0].wait_for = []
|
|
@@ -89,3 +94,23 @@ def _to_wait_for(wait_for: Optional[Union[str, List[str]]]) -> Optional[Union[Li
|
|
|
89
94
|
return []
|
|
90
95
|
|
|
91
96
|
return wait_for if isinstance(wait_for, list) else [wait_for]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _get_cloud_account(steps: List[V1PipelineStep]) -> Optional[str]:
    """Return the smallest cluster ID referenced by the given steps.

    Each step's spec is obtained through ``_get_spec`` and its ``cluster_id``
    collected into a set; the first element in sorted order is returned.

    Returns:
        ``None`` when ``steps`` is empty, otherwise the lowest-sorting cluster ID.
    """
    if len(steps) == 0:
        return None

    cluster_ids = {_get_spec(step).cluster_id for step in steps}
    return min(cluster_ids)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _get_spec(step: Any) -> V1JobSpec:
    """Return the ``spec`` of the payload that matches the step's type.

    Deployment steps use ``step.deployment``, MMT steps use ``step.mmt`` and
    every other step type falls back to ``step.job``.
    """
    step_type = step.type
    if step_type == V1PipelineStepType.DEPLOYMENT:
        payload = step.deployment
    elif step_type == V1PipelineStepType.MMT:
        payload = step.mmt
    else:
        payload = step.job
    return payload.spec
|
lightning_sdk/studio.py
CHANGED
|
@@ -88,6 +88,9 @@ class Studio:
|
|
|
88
88
|
|
|
89
89
|
self._plugins = {}
|
|
90
90
|
|
|
91
|
+
if self._teamspace is None:
|
|
92
|
+
raise ValueError("Couldn't resolve teamspace from the provided name, org, or user")
|
|
93
|
+
|
|
91
94
|
if provider is not None:
|
|
92
95
|
if isinstance(provider, str) and provider in Provider.__members__:
|
|
93
96
|
provider = Provider(provider)
|