lightning-sdk 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/api/base_studio_api.py +79 -0
- lightning_sdk/api/cluster_api.py +83 -1
- lightning_sdk/api/license_api.py +48 -0
- lightning_sdk/api/llm_api.py +73 -12
- lightning_sdk/api/studio_api.py +50 -1
- lightning_sdk/api/teamspace_api.py +127 -1
- lightning_sdk/api/utils.py +4 -0
- lightning_sdk/base_studio.py +83 -0
- lightning_sdk/cli/create.py +21 -1
- lightning_sdk/cli/delete.py +6 -8
- lightning_sdk/cli/deploy/__init__.py +0 -0
- lightning_sdk/cli/deploy/_auth.py +189 -0
- lightning_sdk/cli/deploy/devbox.py +157 -0
- lightning_sdk/cli/{serve.py → deploy/serve.py} +22 -281
- lightning_sdk/cli/download.py +69 -16
- lightning_sdk/cli/entrypoint.py +1 -1
- lightning_sdk/cli/open.py +21 -2
- lightning_sdk/cli/start.py +12 -3
- lightning_sdk/cli/teamspace_menu.py +9 -1
- lightning_sdk/cli/upload.py +2 -5
- lightning_sdk/lightning_cloud/openapi/__init__.py +29 -0
- lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +121 -0
- lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +9 -1
- lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +226 -0
- lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/api/file_system_service_api.py +178 -0
- lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +984 -101
- lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +525 -0
- lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +93 -0
- lightning_sdk/lightning_cloud/openapi/configuration.py +1 -1
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +28 -0
- lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/cloudspaces_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body.py +331 -0
- lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body1.py +305 -0
- lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/endpoints_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/model_id_versions_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/models_id_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +183 -1
- lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body.py +6 -6
- lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_storage_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +107 -3
- lightning_sdk/lightning_cloud/openapi/models/storage_complete_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/update.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/uploads_upload_id_body1.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_aws_direct_v1.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +3 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_source_type.py +103 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudflare_v1.py +66 -66
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_upload.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_complete_upload.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_delete_deployment_alerting_policy_response.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_event.py +487 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy.py +383 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_frequency.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_operation.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_severity.py +106 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_type.py +111 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +4 -4
- lightning_sdk/lightning_cloud/openapi/models/v1_endpoint.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_external_search_user.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_ge_list_deployment_routing_telemetry_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_instance_open_ports_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_deployment_routing_telemetry_content_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_organization_storage_metadata_response.py +331 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_project_balance_response.py +1 -27
- lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_job_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_events_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_policies_response.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_product_licenses_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +43 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_modify_filesystem_volume_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +183 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_pipeline.py +6 -6
- lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_state.py +111 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_presigned_url.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +409 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_product_license_check_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_project.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +43 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +107 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_project_storage.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_r2_data_connection.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_routing_telemetry.py +253 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_secret_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +2 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_sleep_server_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_trigger_filesystem_upgrade_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_upload_project_artifact_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_usage_report.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +347 -113
- lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
- lightning_sdk/lightning_cloud/rest_client.py +4 -0
- lightning_sdk/llm/llm.py +132 -40
- lightning_sdk/services/__init__.py +1 -1
- lightning_sdk/services/license.py +236 -0
- lightning_sdk/studio.py +62 -1
- lightning_sdk/teamspace.py +68 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/METADATA +1 -1
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/RECORD +122 -86
- /lightning_sdk/services/{finetune/__init__.py → finetune_llm.py} +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -42,6 +42,7 @@ class V1UserRequestedComputeConfig(object):
|
|
|
42
42
|
"""
|
|
43
43
|
swagger_types = {
|
|
44
44
|
'affinity_identifier': 'str',
|
|
45
|
+
'cluster_override': 'str',
|
|
45
46
|
'count': 'int',
|
|
46
47
|
'cpu_image_override': 'str',
|
|
47
48
|
'disk_size': 'int',
|
|
@@ -56,6 +57,7 @@ class V1UserRequestedComputeConfig(object):
|
|
|
56
57
|
|
|
57
58
|
attribute_map = {
|
|
58
59
|
'affinity_identifier': 'affinityIdentifier',
|
|
60
|
+
'cluster_override': 'clusterOverride',
|
|
59
61
|
'count': 'count',
|
|
60
62
|
'cpu_image_override': 'cpuImageOverride',
|
|
61
63
|
'disk_size': 'diskSize',
|
|
@@ -68,9 +70,10 @@ class V1UserRequestedComputeConfig(object):
|
|
|
68
70
|
'spot': 'spot'
|
|
69
71
|
}
|
|
70
72
|
|
|
71
|
-
def __init__(self, affinity_identifier: 'str' =None, count: 'int' =None, cpu_image_override: 'str' =None, disk_size: 'int' =None, gpu_image_override: 'str' =None, id: 'str' =None, name: 'str' =None, requested_run_duration_seconds: 'str' =None, same_compute_on_resume: 'bool' =None, shm_size: 'int' =None, spot: 'bool' =None): # noqa: E501
|
|
73
|
+
def __init__(self, affinity_identifier: 'str' =None, cluster_override: 'str' =None, count: 'int' =None, cpu_image_override: 'str' =None, disk_size: 'int' =None, gpu_image_override: 'str' =None, id: 'str' =None, name: 'str' =None, requested_run_duration_seconds: 'str' =None, same_compute_on_resume: 'bool' =None, shm_size: 'int' =None, spot: 'bool' =None): # noqa: E501
|
|
72
74
|
"""V1UserRequestedComputeConfig - a model defined in Swagger""" # noqa: E501
|
|
73
75
|
self._affinity_identifier = None
|
|
76
|
+
self._cluster_override = None
|
|
74
77
|
self._count = None
|
|
75
78
|
self._cpu_image_override = None
|
|
76
79
|
self._disk_size = None
|
|
@@ -84,6 +87,8 @@ class V1UserRequestedComputeConfig(object):
|
|
|
84
87
|
self.discriminator = None
|
|
85
88
|
if affinity_identifier is not None:
|
|
86
89
|
self.affinity_identifier = affinity_identifier
|
|
90
|
+
if cluster_override is not None:
|
|
91
|
+
self.cluster_override = cluster_override
|
|
87
92
|
if count is not None:
|
|
88
93
|
self.count = count
|
|
89
94
|
if cpu_image_override is not None:
|
|
@@ -128,6 +133,27 @@ class V1UserRequestedComputeConfig(object):
|
|
|
128
133
|
|
|
129
134
|
self._affinity_identifier = affinity_identifier
|
|
130
135
|
|
|
136
|
+
@property
|
|
137
|
+
def cluster_override(self) -> 'str':
|
|
138
|
+
"""Gets the cluster_override of this V1UserRequestedComputeConfig. # noqa: E501
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
:return: The cluster_override of this V1UserRequestedComputeConfig. # noqa: E501
|
|
142
|
+
:rtype: str
|
|
143
|
+
"""
|
|
144
|
+
return self._cluster_override
|
|
145
|
+
|
|
146
|
+
@cluster_override.setter
|
|
147
|
+
def cluster_override(self, cluster_override: 'str'):
|
|
148
|
+
"""Sets the cluster_override of this V1UserRequestedComputeConfig.
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
:param cluster_override: The cluster_override of this V1UserRequestedComputeConfig. # noqa: E501
|
|
152
|
+
:type: str
|
|
153
|
+
"""
|
|
154
|
+
|
|
155
|
+
self._cluster_override = cluster_override
|
|
156
|
+
|
|
131
157
|
@property
|
|
132
158
|
def count(self) -> 'int':
|
|
133
159
|
"""Gets the count of this V1UserRequestedComputeConfig. # noqa: E501
|
|
@@ -35,6 +35,8 @@ from lightning_sdk.lightning_cloud.openapi import (
|
|
|
35
35
|
LitRegistryServiceApi,
|
|
36
36
|
PipelinesServiceApi,
|
|
37
37
|
SchedulesServiceApi,
|
|
38
|
+
ProductLicenseServiceApi,
|
|
39
|
+
CloudSpaceEnvironmentTemplateServiceApi
|
|
38
40
|
)
|
|
39
41
|
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
|
|
40
42
|
from lightning_sdk.lightning_cloud.source_code.logs_socket_api import LightningLogsSocketAPI
|
|
@@ -97,6 +99,8 @@ class GridRestClient(
|
|
|
97
99
|
LitRegistryServiceApi,
|
|
98
100
|
PipelinesServiceApi,
|
|
99
101
|
SchedulesServiceApi,
|
|
102
|
+
ProductLicenseServiceApi,
|
|
103
|
+
CloudSpaceEnvironmentTemplateServiceApi
|
|
100
104
|
):
|
|
101
105
|
|
|
102
106
|
def __init__(self, api_client: Optional[ApiClient] = None):
|
lightning_sdk/llm/llm.py
CHANGED
|
@@ -1,52 +1,113 @@
|
|
|
1
|
-
|
|
1
|
+
import os
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Dict, Generator, List, Optional, Set, Tuple, Union
|
|
2
4
|
|
|
3
|
-
from lightning_sdk.api import UserApi
|
|
4
5
|
from lightning_sdk.api.llm_api import LLMApi
|
|
5
|
-
from lightning_sdk.
|
|
6
|
+
from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
|
|
6
7
|
from lightning_sdk.lightning_cloud.openapi import V1Assistant
|
|
8
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
|
|
7
9
|
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
|
|
8
10
|
from lightning_sdk.organization import Organization
|
|
9
|
-
from lightning_sdk.
|
|
10
|
-
from lightning_sdk.
|
|
11
|
+
from lightning_sdk.owner import Owner
|
|
12
|
+
from lightning_sdk.teamspace import Teamspace
|
|
13
|
+
from lightning_sdk.utils.resolve import _get_authed_user, _resolve_org, _resolve_teamspace
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
class LLM:
|
|
14
17
|
    def __init__(
        self,
        name: str,
        teamspace: Optional[str] = None,
    ) -> None:
        """Initializes the LLM instance with teamspace information, which is required for billing purposes.

        Teamspace information is resolved through the following methods:
        1. `.lightning/credentials.json` - Attempts to retrieve the teamspace from the local credentials file.
        2. Environment Variables - Checks for `LIGHTNING_*` environment variables.
        3. User Authentication - Redirects the user to the login page if teamspace information is not found.

        If none of the above yields a teamspace, the first teamspace offered by ``_TeamspacesMenu`` for the
        authenticated user is used and a ``UserWarning`` is emitted.

        The model ``name`` is split by ``_parse_model_name`` into a provider and a model name. If the provider
        resolves to an organization whose models the user can list, those org models are used. Otherwise the
        org models of the billing teamspace's owner are used, but only when that owner is an ``Organization``.
        Public and user models are always looked up.

        Args:
            name (str): The name of the model or resource.
            teamspace (Optional[str]): The specified teamspace for billing. If not provided, it will be resolved
                through the above methods.

        Raises:
            ValueError: If teamspace information cannot be resolved.
        """
        menu = _TeamspacesMenu()
        user = _get_authed_user()
        # Mapping of teamspaces the authenticated user can choose from; used both to look up an
        # explicitly named teamspace and as the fallback source below.
        possible_teamspaces = menu._get_possible_teamspaces(user)
        if teamspace is None:
            # get current teamspace
            self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
        else:
            self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))

        if self._teamspace is None:
            # select the first available teamspace
            first_teamspace = next(iter(possible_teamspaces.values()), None)

            if first_teamspace:
                self._teamspace = Teamspace(
                    name=first_teamspace["name"],
                    org=first_teamspace["org"],
                    user=first_teamspace["user"],
                )
                warnings.warn(
                    f"No teamspace given. Using teamspace: {self._teamspace.name}.",
                    UserWarning,
                    stacklevel=2,
                )

        # Still unresolved: the user has no teamspace available at all.
        if self._teamspace is None:
            raise ValueError("Teamspace is required for billing but could not be resolved. ")

        self._user = user

        self._model_provider, self._model_name = self._parse_model_name(name)

        # Created before the model lookups below; presumably used by the ``_get_*_models`` helpers.
        self._llm_api = LLMApi()

        try:
            # check if it is a org model
            self._org = _resolve_org(self._model_provider)

            try:
                # check if user has access to the org
                self._org_models = self._build_model_lookup(self._get_org_models())
            except ApiException:
                warnings.warn(
                    f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
                    " Proceeding with appropriate org models, user models, or public models.",
                    UserWarning,
                    stacklevel=2,
                )
                self._model_provider = None
                # Re-raise on purpose: the outer ``except ApiException`` performs the fallback below.
                raise
        except ApiException:
            # Provider is not an accessible org: fall back to the billing teamspace's owner,
            # which only contributes org models when it is an Organization.
            if isinstance(self._teamspace.owner, Organization):
                self._org = self._teamspace.owner
            else:
                self._org = None
            self._org_models = self._build_model_lookup(self._get_org_models())

        self._public_models = self._build_model_lookup(self._get_public_models())
        self._user_models = self._build_model_lookup(self._get_user_models())
        self._model = self._get_model()
        # Local cache: conversation name -> conversation id.
        self._conversations = {}
|
|
49
98
|
|
|
99
|
+
@property
|
|
100
|
+
def name(self) -> str:
|
|
101
|
+
return self._model_name
|
|
102
|
+
|
|
103
|
+
@property
|
|
104
|
+
def provider(self) -> str:
|
|
105
|
+
return self._model_provider
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def owner(self) -> Optional[Owner]:
|
|
109
|
+
return self._teamspace.owner
|
|
110
|
+
|
|
50
111
|
def _parse_model_name(self, name: str) -> Tuple[str, str]:
|
|
51
112
|
parts = name.split("/")
|
|
52
113
|
if len(parts) == 1:
|
|
@@ -95,47 +156,76 @@ class LLM:
|
|
|
95
156
|
available_models_str = "\n".join(available_models)
|
|
96
157
|
raise ValueError(f"Model '{self._model_name}' not found. \nAvailable models: \n{available_models_str}")
|
|
97
158
|
|
|
98
|
-
def _get_conversations(self) ->
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
159
|
+
def _get_conversations(self) -> None:
|
|
160
|
+
conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
|
|
161
|
+
for conversation in conversations:
|
|
162
|
+
if conversation.name and conversation.name not in self._conversations:
|
|
163
|
+
self._conversations[conversation.name] = conversation.id
|
|
164
|
+
|
|
165
|
+
def _stream_chat_response(
|
|
166
|
+
self, result: Generator[V1ConversationResponseChunk, None, None], conversation: Optional[str] = None
|
|
167
|
+
) -> Generator[str, None, None]:
|
|
168
|
+
first_line = next(result, None)
|
|
169
|
+
if first_line:
|
|
170
|
+
if conversation and first_line.conversation_id:
|
|
171
|
+
self._conversations[conversation] = first_line.conversation_id
|
|
172
|
+
yield first_line.choices[0].delta.content
|
|
102
173
|
|
|
103
|
-
|
|
104
|
-
|
|
174
|
+
for line in result:
|
|
175
|
+
yield line.choices[0].delta.content
|
|
105
176
|
|
|
106
177
|
def chat(
|
|
107
178
|
self,
|
|
108
179
|
prompt: str,
|
|
109
180
|
system_prompt: Optional[str] = None,
|
|
110
181
|
max_completion_tokens: Optional[int] = 500,
|
|
182
|
+
images: Optional[Union[List[str], str]] = None,
|
|
111
183
|
conversation: Optional[str] = None,
|
|
112
|
-
|
|
184
|
+
metadata: Optional[Dict[str, str]] = None,
|
|
185
|
+
stream: bool = False,
|
|
186
|
+
upload_local_images: bool = False,
|
|
187
|
+
) -> Union[str, Generator[str, None, None]]:
|
|
113
188
|
if conversation and conversation not in self._conversations:
|
|
114
|
-
self.
|
|
189
|
+
self._get_conversations()
|
|
190
|
+
|
|
191
|
+
if images:
|
|
192
|
+
if isinstance(images, str):
|
|
193
|
+
images = [images]
|
|
194
|
+
for image in images:
|
|
195
|
+
if not isinstance(image, str):
|
|
196
|
+
raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
|
|
197
|
+
if not image.startswith("http") and upload_local_images:
|
|
198
|
+
self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
|
|
115
199
|
|
|
116
200
|
conversation_id = self._conversations.get(conversation) if conversation else None
|
|
117
201
|
output = self._llm_api.start_conversation(
|
|
118
202
|
prompt=prompt,
|
|
119
203
|
system_prompt=system_prompt,
|
|
120
204
|
max_completion_tokens=max_completion_tokens,
|
|
205
|
+
images=images,
|
|
121
206
|
assistant_id=self._model.id,
|
|
122
207
|
conversation_id=conversation_id,
|
|
208
|
+
billing_project_id=self._teamspace.id,
|
|
209
|
+
metadata=metadata,
|
|
210
|
+
name=conversation,
|
|
211
|
+
stream=stream,
|
|
123
212
|
)
|
|
124
|
-
if
|
|
125
|
-
|
|
126
|
-
|
|
213
|
+
if not stream:
|
|
214
|
+
if conversation and not conversation_id:
|
|
215
|
+
self._conversations[conversation] = output.conversation_id
|
|
216
|
+
return output.choices[0].delta.content
|
|
217
|
+
return self._stream_chat_response(output, conversation=conversation)
|
|
127
218
|
|
|
128
219
|
def list_conversations(self) -> List[Dict]:
|
|
129
|
-
self.
|
|
220
|
+
self._get_conversations()
|
|
130
221
|
return list(self._conversations.keys())
|
|
131
222
|
|
|
132
223
|
def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
|
|
133
224
|
return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
|
|
134
225
|
|
|
135
226
|
def get_history(self, conversation: str) -> Optional[List[Dict]]:
|
|
136
|
-
# TODO: after updating backend, this will fetch conversation from backend
|
|
137
227
|
if conversation not in self._conversations:
|
|
138
|
-
self.
|
|
228
|
+
self._get_conversations()
|
|
139
229
|
|
|
140
230
|
if conversation not in self._conversations:
|
|
141
231
|
raise ValueError(
|
|
@@ -152,6 +242,8 @@ class LLM:
|
|
|
152
242
|
return history
|
|
153
243
|
|
|
154
244
|
def reset_conversation(self, conversation: str) -> None:
|
|
245
|
+
if conversation not in self._conversations:
|
|
246
|
+
self._get_conversations()
|
|
155
247
|
if conversation in self._conversations:
|
|
156
248
|
self._llm_api.reset_conversation(
|
|
157
249
|
assistant_id=self._model.id,
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from lightning_sdk.services.file_endpoint import Client
|
|
2
|
-
from lightning_sdk.services.
|
|
2
|
+
from lightning_sdk.services.finetune_llm import LLMFinetune
|
|
3
3
|
from lightning_sdk.services.utilities import download_file
|
|
4
4
|
|
|
5
5
|
__all__ = ["LLMFinetune", "Client", "download_file"]
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import socket
|
|
5
|
+
import threading
|
|
6
|
+
from functools import partial
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
from lightning_sdk.api.license_api import LicenseApi
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LightningLicense:
    """This class is used to manage the license for the Lightning SDK.

    The license key is taken from the constructor or, when missing, looked up first in a
    ``.license_key`` file in the installed package root and then in ``~/.lightning/licenses.json``.
    """

    _is_valid: Optional[bool] = None
    # Quoted so the class body does not evaluate the imported API type at definition time.
    _license_api: "Optional[LicenseApi]" = None
    _stream_messages: Optional[callable] = None

    def __init__(
        self,
        name: str,
        license_key: Optional[str] = None,
        product_version: Optional[str] = None,
        product_type: str = "package",
        stream_messages: callable = print,
    ) -> None:
        """Store the product details; nothing is validated until :attr:`is_valid` is read.

        Args:
            name: Product name (for packages, dashes are mapped to underscores for imports).
            license_key: Explicit license key; looked up on disk when omitted.
            product_version: Product version; auto-determined for packages when omitted.
            product_type: Product type forwarded to the license service.
            stream_messages: Callable receiving user-facing messages.
        """
        self._product_name = name
        self._license_key = license_key
        self._product_version = product_version
        self.product_type = product_type
        self._is_valid = None
        self._license_api = None
        self._stream_messages = stream_messages

    def validate_license(self) -> bool:
        """Validate the license key with the license service.

        Raises:
            ConnectionError: If the system appears to be offline (see :meth:`is_online`).
        """
        if not self.is_online():
            raise ConnectionError("No internet connection.")

        self._license_api = LicenseApi()
        return self._license_api.valid_license(
            license_key=self.license_key,
            product_name=self.product_name,
            product_version=self.product_version,
            product_type=self.product_type,
        )

    @staticmethod
    def is_online(timeout: float = 2.0) -> bool:
        """Check if the system is online by attempting to connect to a public DNS server (Google's).

        This is a simple way to check for internet connectivity.

        Args:
            timeout: The timeout for the connection attempt.
        """
        try:
            # Only reachability matters; close the probe socket immediately instead of leaking it.
            with socket.create_connection(("8.8.8.8", 53), timeout=timeout):
                return True
        except OSError:
            return False

    @property
    def is_valid(self) -> Optional[bool]:
        """Check if the license key is valid.

        license validation within package:
        - user online with valid key -> everything as now
        - user online with invalid key -> warning using wrong key + instructions
        - user online with no key -> warning for missing license approval + instructions
        - user offline with a key -> small warning that key could not be verified
        - user offline with no key -> warning for missing license approval + instructions

        Returns:
            The cached/validated result, or None when validation could not be performed
            (no key, or offline).
        """
        if isinstance(self._is_valid, bool):
            # if the license key is already validated, return the cached value
            return self._is_valid
        if not self.product_version:
            self._stream_messages("Product version is not set correctly, consider leave it empty for auto-determine.")
        # Resolve once: the property searches the filesystem on every access while no key is found.
        license_key = self.license_key
        if not license_key:
            self._stream_messages(
                "License key is not set neither cannot be found in the package root or user home."
                " Please make sure you have signed the license agreement and set the license key."
                " For more information, please refer to the documentation.",
            )
            # Nothing to verify, so skip the network probe; the message above covers online and offline.
            return self._is_valid
        if self.is_online():
            self._is_valid = self.validate_license()
        else:
            self._stream_messages(
                "License key is set but the system is offline. "
                "Please make sure you have a valid license key and the system is online."
            )
        return self._is_valid

    @property
    def has_required_details(self) -> bool:
        """Check that the license key, product name and product type are all set."""
        return bool(self.license_key and self.product_name and self.product_type)

    @staticmethod
    def _find_package_license_key(package_name: str) -> Optional[str]:
        """Read the license key from a ``.license_key`` file in the installed package's root directory.

        Args:
            package_name: Importable name of the package; an empty name yields None.

        Returns:
            The stripped file content, or None when the package, its directory or the file is missing.
        """
        # ``import importlib`` alone does not guarantee that the ``util`` submodule is loaded.
        import importlib.util

        if not package_name:
            return None
        try:
            spec = importlib.util.find_spec(package_name)
        except (ImportError, ValueError):
            # ImportError: a parent package is missing; ValueError: module has ``__spec__`` set to None.
            return None
        # find_spec returns None for unknown names; plain (non-package) modules have no search locations.
        if spec is None or not spec.submodule_search_locations:
            return None
        package_dir = next(iter(spec.submodule_search_locations), None)
        if package_dir is None:
            return None
        license_file = os.path.join(package_dir, ".license_key")
        try:
            with open(license_file, encoding="utf-8") as fp:
                return fp.read().strip()
        except (OSError, UnicodeDecodeError):
            return None

    @staticmethod
    def _find_user_license_key(package_name: str) -> Optional[str]:
        """Find the license key in the user home as .lightning/licenses.json.

        The file is expected to hold a JSON object mapping package names to keys; the name is tried
        lower-cased as given, with dashes as underscores, and with underscores as dashes.

        Args:
            package_name: The name of the package.
        """
        package_name = package_name.lower()
        license_file = Path.home() / ".lightning" / "licenses.json"
        try:
            with open(license_file, encoding="utf-8") as fp:
                licenses = json.load(fp)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return None
        if not isinstance(licenses, dict):
            return None
        # Check for the license key in the licenses.json file
        for name in (package_name, package_name.replace("-", "_"), package_name.replace("_", "-")):
            if name in licenses:
                return licenses[name]
        return None

    @staticmethod
    def _determine_package_version(package_name: str) -> Optional[str]:
        """Return the ``__version__`` of the importable package ``package_name``.

        Args:
            package_name: Importable package name.

        Returns:
            The version, or None when the package cannot be imported or has no ``__version__``.
        """
        try:
            pkg = importlib.import_module(package_name)
            return getattr(pkg, "__version__", None)
        except ImportError:
            return None

    @property
    def license_key(self) -> Optional[str]:
        """Get the license key."""
        if not self._license_key:
            # If the license key is not set, first try to find it in the package root
            self._license_key = self._find_package_license_key(self.product_name.replace("-", "_"))
            # If not found, try to find it in the user home
            if not self._license_key:
                self._license_key = self._find_user_license_key(self.product_name)
        return self._license_key

    @property
    def product_name(self) -> str:
        """Get the product name."""
        return self._product_name

    @property
    def product_version(self) -> Optional[str]:
        """Get the product version, auto-determined from the installed package when not given."""
        if not self._product_version and self.product_type == "package":
            self._product_version = self._determine_package_version(self.product_name.replace("-", "_"))
        return self._product_version
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def check_license(
    name: str,
    license_key: Optional[str] = None,
    product_version: Optional[str] = None,
    product_type: str = "package",
    stream_messages: callable = print,
) -> None:
    """Validate a product license once and report an invalid key via ``stream_messages``.

    Args:
        name: The name of the product.
        license_key: The license key to check.
        product_version: The version of the product.
        product_type: The type of the product.
        stream_messages: A callable that receives the messages to report.
    """
    license_check = LightningLicense(
        name=name,
        license_key=license_key,
        product_version=product_version,
        product_type=product_type,
        stream_messages=stream_messages,
    )
    # Only an explicit ``False`` triggers the warning; any other value (e.g. ``None``)
    # is not treated as invalid.
    if license_check.is_valid is not False:
        return
    stream_messages(
        "License key is not valid.\n"
        f" Key: {license_check.license_key}\n"
        " Please make sure you have a valid license key."
    )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def check_license_in_background(
    name: str,
    license_key: Optional[str] = None,
    product_version: Optional[str] = None,
    product_type: str = "package",
    stream_messages: callable = print,
) -> threading.Thread:
    """Start ``check_license`` on a daemon thread and return the started thread.

    Args:
        name: The name of the product.
        license_key: The license key to check.
        product_version: The version of the product.
        product_type: The type of the product.
        stream_messages: A callable that receives the messages to report.

    Returns:
        The already-started daemon ``threading.Thread``.
    """
    check_kwargs = {
        "name": name,
        "license_key": license_key,
        "product_version": product_version,
        "product_type": product_type,
        "stream_messages": stream_messages,
    }
    # daemon=True: the check must not keep the interpreter alive at exit.
    worker = threading.Thread(target=partial(check_license, **check_kwargs), daemon=True)
    worker.start()
    return worker
|
lightning_sdk/studio.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
import glob
|
|
2
2
|
import os
|
|
3
3
|
import warnings
|
|
4
|
+
from enum import Enum
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Union
|
|
5
6
|
|
|
6
7
|
from tqdm.auto import tqdm
|
|
7
8
|
|
|
9
|
+
from lightning_sdk.api.cluster_api import ClusterApi
|
|
8
10
|
from lightning_sdk.api.studio_api import StudioApi
|
|
9
11
|
from lightning_sdk.api.utils import _machine_to_compute_name
|
|
10
12
|
from lightning_sdk.constants import _LIGHTNING_DEBUG
|
|
13
|
+
from lightning_sdk.lightning_cloud.openapi import V1CloudSpaceSourceType
|
|
11
14
|
from lightning_sdk.machine import Machine
|
|
12
15
|
from lightning_sdk.organization import Organization
|
|
13
16
|
from lightning_sdk.owner import Owner
|
|
@@ -24,6 +27,19 @@ if TYPE_CHECKING:
|
|
|
24
27
|
_logger = _setup_logger(__name__)
|
|
25
28
|
|
|
26
29
|
|
|
30
|
+
class Provider(Enum):
    """Providers a Studio's machine can be created on.

    Each member's name and value are the same string, so a provider can be looked
    up either by name (``Provider["AWS"]``) or by value (``Provider("AWS")``).
    """

    # Machine providers based on v1CloudProvider
    AWS = "AWS"
    GCP = "GCP"
    VULTR = "VULTR"
    LAMBDA_LABS = "LAMBDA_LABS"
    DGX = "DGX"
    VOLTAGE_PARK = "VOLTAGE_PARK"
    NEBIUS = "NEBIUS"
    CLOUDFLARE = "CLOUDFLARE"
    LIGHTNING = "LIGHTNING"
|
|
41
|
+
|
|
42
|
+
|
|
27
43
|
class Studio:
|
|
28
44
|
"""A single Lightning AI Studio.
|
|
29
45
|
|
|
@@ -38,6 +54,8 @@ class Studio:
|
|
|
38
54
|
cloud_account: the name of the cloud account, the studio should be created on.
|
|
39
55
|
Doesn't matter when the studio already exists.
|
|
40
56
|
create_ok: whether the studio will be created if it does not yet exist. Defaults to True
|
|
57
|
+
provider: the provider of the machine, the studio should be created on.
|
|
58
|
+
|
|
41
59
|
Note:
|
|
42
60
|
Since a teamspace can either be owned by an org or by a user directly,
|
|
43
61
|
only one of the arguments can be provided.
|
|
@@ -56,8 +74,11 @@ class Studio:
|
|
|
56
74
|
cloud_account: Optional[str] = None,
|
|
57
75
|
create_ok: bool = True,
|
|
58
76
|
cluster: Optional[str] = None, # deprecated in favor of cloud_account
|
|
77
|
+
provider: Optional[str] = None,
|
|
78
|
+
source: Optional[V1CloudSpaceSourceType] = None,
|
|
59
79
|
) -> None:
|
|
60
80
|
self._studio_api = StudioApi()
|
|
81
|
+
self._cluster_api = ClusterApi()
|
|
61
82
|
|
|
62
83
|
self._teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
|
|
63
84
|
self._cloud_account = _resolve_deprecated_cluster(cloud_account, cluster)
|
|
@@ -65,6 +86,16 @@ class Studio:
|
|
|
65
86
|
|
|
66
87
|
self._plugins = {}
|
|
67
88
|
|
|
89
|
+
if provider is not None:
|
|
90
|
+
if isinstance(provider, str) and provider in Provider.__members__:
|
|
91
|
+
provider = Provider(provider)
|
|
92
|
+
else:
|
|
93
|
+
raise ValueError(f"Invalid provider: {provider}. Must be one of {Provider.__members__.keys()}.")
|
|
94
|
+
self._cloud_account = self._cluster_api.get_cluster_provider_mapping(
|
|
95
|
+
self._teamspace.id,
|
|
96
|
+
self._teamspace.owner.id,
|
|
97
|
+
)[provider.value]
|
|
98
|
+
|
|
68
99
|
if name is None:
|
|
69
100
|
studio_id = os.environ.get("LIGHTNING_CLOUD_SPACE_ID", None)
|
|
70
101
|
if studio_id is None:
|
|
@@ -76,7 +107,7 @@ class Studio:
|
|
|
76
107
|
except ValueError as e:
|
|
77
108
|
if create_ok:
|
|
78
109
|
self._studio = self._studio_api.create_studio(
|
|
79
|
-
name, self._teamspace.id, cloud_account=self._cloud_account
|
|
110
|
+
name, self._teamspace.id, cloud_account=self._cloud_account, source=source
|
|
80
111
|
)
|
|
81
112
|
else:
|
|
82
113
|
raise ValueError(f"Studio {name} does not exist.") from e
|
|
@@ -221,6 +252,36 @@ class Studio:
|
|
|
221
252
|
self._studio.id, self._teamspace.id, machine, interruptible=interruptible
|
|
222
253
|
)
|
|
223
254
|
|
|
255
|
+
def run_and_detach(
    self, *commands: str, timeout: float = 10, check_interval: float = 1
) -> Tuple[str, Optional[int]]:
    """Run commands on the Studio and collect their output for up to ``timeout`` seconds.

    The commands are meant to keep running in the background. This method echoes
    and collects whatever output the status stream yields, then returns once that
    stream ends.

    Args:
        *commands: the commands to run on the Studio.
        timeout: wait for this many seconds for the command to finish.
        check_interval: check the status of the command every this many seconds.

    Returns:
        A tuple ``(output, exit_code)``: the concatenated output lines and the last
        exit code yielded by the stream (``None`` if the stream yielded nothing).

    Raises:
        ValueError: if ``check_interval`` is greater than ``timeout``.
        RuntimeError: if the Studio is not running.
    """
    if check_interval > timeout:
        raise ValueError("check_interval must be less than or equal to timeout")

    if _LIGHTNING_DEBUG:
        print(f"Running {commands=}")
    status = self.status
    if status != Status.Running:
        raise RuntimeError(f"Cannot run a command in a studio that is not running. Studio {self.name} is {status}.")

    iter_output = self._studio_api.run_studio_commands_and_yield(
        self._studio.id, self._teamspace.id, *commands, timeout=timeout, check_interval=check_interval
    )

    lines = []
    code = None
    for line, exit_code in iter_output:
        # Echo each line as it arrives so the caller sees progress live.
        print(line)
        lines.append(line)
        # Keep only the most recent exit code reported by the stream.
        code = exit_code
    return "".join(lines), code
|
|
284
|
+
|
|
224
285
|
def run_with_exit_code(self, *commands: str) -> Tuple[str, int]:
|
|
225
286
|
"""Runs given commands on the Studio while returning output and exit code.
|
|
226
287
|
|