lightning-sdk 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/api/llm_api.py +28 -5
- lightning_sdk/api/studio_api.py +17 -0
- lightning_sdk/cli/entrypoint.py +1 -1
- lightning_sdk/cli/serve.py +149 -39
- lightning_sdk/deployment/deployment.py +2 -2
- lightning_sdk/lightning_cloud/openapi/__init__.py +6 -0
- lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
- lightning_sdk/lightning_cloud/openapi/api/git_credentials_service_api.py +497 -0
- lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +14 -5
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +5 -0
- lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_security_options.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_git_credentials_request.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_delete_git_credentials_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +2 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_git_credentials.py +227 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_git_credentials_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_job_resources_response.py +15 -15
- lightning_sdk/lightning_cloud/openapi/models/v1_nebius_direct_v1.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +53 -1
- lightning_sdk/llm/llm.py +134 -30
- lightning_sdk/plugin.py +19 -0
- lightning_sdk/studio.py +33 -0
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/METADATA +1 -1
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/RECORD +34 -28
- /lightning_sdk/cli/{docker.py → docker_cli.py} +0 -0
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -85,6 +85,7 @@ class V1UserFeatures(object):
|
|
|
85
85
|
'mmt_strategy_selector': 'bool',
|
|
86
86
|
'multicloud_saas': 'bool',
|
|
87
87
|
'multiple_studio_versions': 'bool',
|
|
88
|
+
'nerf_fs_nonpaying': 'bool',
|
|
88
89
|
'org_admin_alerts': 'bool',
|
|
89
90
|
'org_level_member_permissions': 'bool',
|
|
90
91
|
'org_usage_limits': 'bool',
|
|
@@ -110,6 +111,7 @@ class V1UserFeatures(object):
|
|
|
110
111
|
'slurm_machine_selector': 'bool',
|
|
111
112
|
'stop_ide_container_on_shutdown': 'bool',
|
|
112
113
|
'studio_config': 'bool',
|
|
114
|
+
'studio_deployment': 'bool',
|
|
113
115
|
'studio_on_stop': 'bool',
|
|
114
116
|
'studio_version_visibility': 'bool',
|
|
115
117
|
'studios_dashboard': 'bool',
|
|
@@ -167,6 +169,7 @@ class V1UserFeatures(object):
|
|
|
167
169
|
'mmt_strategy_selector': 'mmtStrategySelector',
|
|
168
170
|
'multicloud_saas': 'multicloudSaas',
|
|
169
171
|
'multiple_studio_versions': 'multipleStudioVersions',
|
|
172
|
+
'nerf_fs_nonpaying': 'nerfFsNonpaying',
|
|
170
173
|
'org_admin_alerts': 'orgAdminAlerts',
|
|
171
174
|
'org_level_member_permissions': 'orgLevelMemberPermissions',
|
|
172
175
|
'org_usage_limits': 'orgUsageLimits',
|
|
@@ -192,6 +195,7 @@ class V1UserFeatures(object):
|
|
|
192
195
|
'slurm_machine_selector': 'slurmMachineSelector',
|
|
193
196
|
'stop_ide_container_on_shutdown': 'stopIdeContainerOnShutdown',
|
|
194
197
|
'studio_config': 'studioConfig',
|
|
198
|
+
'studio_deployment': 'studioDeployment',
|
|
195
199
|
'studio_on_stop': 'studioOnStop',
|
|
196
200
|
'studio_version_visibility': 'studioVersionVisibility',
|
|
197
201
|
'studios_dashboard': 'studiosDashboard',
|
|
@@ -204,7 +208,7 @@ class V1UserFeatures(object):
|
|
|
204
208
|
'weka': 'weka'
|
|
205
209
|
}
|
|
206
210
|
|
|
207
|
-
def __init__(self, affiliate_links: 'bool' =None, agents_v2: 'bool' =None, ai_hub_monetization: 'bool' =None, auto_fast_load: 'bool' =None, auto_join_orgs: 'bool' =None, b2c_experience: 'bool' =None, byoc_litcr: 'bool' =None, cap_add: 'list[str]' =None, cap_drop: 'list[str]' =None, capacity_reservation_byoc: 'bool' =None, capacity_reservation_dry_run: 'bool' =None, chat_models: 'bool' =None, cloud_space_environment_templates: 'bool' =None, code_tab: 'bool' =None, collab_screen_sharing: 'bool' =None, concurrent_gpu_limit: 'bool' =None, cost_attribution_settings: 'bool' =None, custom_app_domain: 'bool' =None, custom_instance_types: 'bool' =None, datasets: 'bool' =None, default_one_cluster: 'bool' =None, deployment_alerts: 'bool' =None, deployment_persistent_disk: 'bool' =None, dgx_cloud: 'bool' =None, doc_helper_chat: 'bool' =None, docs_agent: 'bool' =None, drive_v2: 'bool' =None, enable_storage_limits: 'bool' =None, enterprise_compute_admin: 'bool' =None, fair_share: 'bool' =None, featured_studios_admin: 'bool' =None, filestore: 'bool' =None, gcp_local_disk_binding: 'bool' =None, inactive_notify_delete: 'bool' =None, instant_capacity_reservation: 'bool' =None, job_artifacts_v2: 'bool' =None, lambda_labs: 'bool' =None, landing_studios: 'bool' =None, lit_logger: 'bool' =None, marketplace: 'bool' =None, mmt_fault_tolerance: 'bool' =None, mmt_strategy_selector: 'bool' =None, multicloud_saas: 'bool' =None, multiple_studio_versions: 'bool' =None, org_admin_alerts: 'bool' =None, org_level_member_permissions: 'bool' =None, org_usage_limits: 'bool' =None, pipelines: 'bool' =None, plugin_distributed: 'bool' =None, plugin_inference: 'bool' =None, plugin_label_studio: 'bool' =None, plugin_langflow: 'bool' =None, plugin_python_profiler: 'bool' =None, plugin_service: 'bool' =None, plugin_sweeps: 'bool' =None, pricing_updates: 'bool' =None, product_generator: 'bool' =None, project_selector: 'bool' =None, publish_pipelines: 'bool' =None, r2_data_connections: 'bool' =None, restartable_jobs: 'bool' =None, runnable_public_studio_page: 'bool' =None, security_docs: 'bool' =None, show_dev_admin: 'bool' =None, single_wallet: 'bool' =None, slurm: 'bool' =None, slurm_machine_selector: 'bool' =None, stop_ide_container_on_shutdown: 'bool' =None, studio_config: 'bool' =None, studio_on_stop: 'bool' =None, studio_version_visibility: 'bool' =None, studios_dashboard: 'bool' =None, studios_dashboard_system_metrics: 'bool' =None, teamspace_storage_tab: 'bool' =None, trainium2: 'bool' =None, use_rclone_mounts_only: 'bool' =None, voltage_park: 'bool' =None, vultr: 'bool' =None, weka: 'bool' =None): # noqa: E501
|
|
211
|
+
def __init__(self, affiliate_links: 'bool' =None, agents_v2: 'bool' =None, ai_hub_monetization: 'bool' =None, auto_fast_load: 'bool' =None, auto_join_orgs: 'bool' =None, b2c_experience: 'bool' =None, byoc_litcr: 'bool' =None, cap_add: 'list[str]' =None, cap_drop: 'list[str]' =None, capacity_reservation_byoc: 'bool' =None, capacity_reservation_dry_run: 'bool' =None, chat_models: 'bool' =None, cloud_space_environment_templates: 'bool' =None, code_tab: 'bool' =None, collab_screen_sharing: 'bool' =None, concurrent_gpu_limit: 'bool' =None, cost_attribution_settings: 'bool' =None, custom_app_domain: 'bool' =None, custom_instance_types: 'bool' =None, datasets: 'bool' =None, default_one_cluster: 'bool' =None, deployment_alerts: 'bool' =None, deployment_persistent_disk: 'bool' =None, dgx_cloud: 'bool' =None, doc_helper_chat: 'bool' =None, docs_agent: 'bool' =None, drive_v2: 'bool' =None, enable_storage_limits: 'bool' =None, enterprise_compute_admin: 'bool' =None, fair_share: 'bool' =None, featured_studios_admin: 'bool' =None, filestore: 'bool' =None, gcp_local_disk_binding: 'bool' =None, inactive_notify_delete: 'bool' =None, instant_capacity_reservation: 'bool' =None, job_artifacts_v2: 'bool' =None, lambda_labs: 'bool' =None, landing_studios: 'bool' =None, lit_logger: 'bool' =None, marketplace: 'bool' =None, mmt_fault_tolerance: 'bool' =None, mmt_strategy_selector: 'bool' =None, multicloud_saas: 'bool' =None, multiple_studio_versions: 'bool' =None, nerf_fs_nonpaying: 'bool' =None, org_admin_alerts: 'bool' =None, org_level_member_permissions: 'bool' =None, org_usage_limits: 'bool' =None, pipelines: 'bool' =None, plugin_distributed: 'bool' =None, plugin_inference: 'bool' =None, plugin_label_studio: 'bool' =None, plugin_langflow: 'bool' =None, plugin_python_profiler: 'bool' =None, plugin_service: 'bool' =None, plugin_sweeps: 'bool' =None, pricing_updates: 'bool' =None, product_generator: 'bool' =None, project_selector: 'bool' =None, publish_pipelines: 'bool' =None, r2_data_connections: 'bool' =None, restartable_jobs: 'bool' =None, runnable_public_studio_page: 'bool' =None, security_docs: 'bool' =None, show_dev_admin: 'bool' =None, single_wallet: 'bool' =None, slurm: 'bool' =None, slurm_machine_selector: 'bool' =None, stop_ide_container_on_shutdown: 'bool' =None, studio_config: 'bool' =None, studio_deployment: 'bool' =None, studio_on_stop: 'bool' =None, studio_version_visibility: 'bool' =None, studios_dashboard: 'bool' =None, studios_dashboard_system_metrics: 'bool' =None, teamspace_storage_tab: 'bool' =None, trainium2: 'bool' =None, use_rclone_mounts_only: 'bool' =None, voltage_park: 'bool' =None, vultr: 'bool' =None, weka: 'bool' =None): # noqa: E501
|
|
208
212
|
"""V1UserFeatures - a model defined in Swagger""" # noqa: E501
|
|
209
213
|
self._affiliate_links = None
|
|
210
214
|
self._agents_v2 = None
|
|
@@ -250,6 +254,7 @@ class V1UserFeatures(object):
|
|
|
250
254
|
self._mmt_strategy_selector = None
|
|
251
255
|
self._multicloud_saas = None
|
|
252
256
|
self._multiple_studio_versions = None
|
|
257
|
+
self._nerf_fs_nonpaying = None
|
|
253
258
|
self._org_admin_alerts = None
|
|
254
259
|
self._org_level_member_permissions = None
|
|
255
260
|
self._org_usage_limits = None
|
|
@@ -275,6 +280,7 @@ class V1UserFeatures(object):
|
|
|
275
280
|
self._slurm_machine_selector = None
|
|
276
281
|
self._stop_ide_container_on_shutdown = None
|
|
277
282
|
self._studio_config = None
|
|
283
|
+
self._studio_deployment = None
|
|
278
284
|
self._studio_on_stop = None
|
|
279
285
|
self._studio_version_visibility = None
|
|
280
286
|
self._studios_dashboard = None
|
|
@@ -374,6 +380,8 @@ class V1UserFeatures(object):
|
|
|
374
380
|
self.multicloud_saas = multicloud_saas
|
|
375
381
|
if multiple_studio_versions is not None:
|
|
376
382
|
self.multiple_studio_versions = multiple_studio_versions
|
|
383
|
+
if nerf_fs_nonpaying is not None:
|
|
384
|
+
self.nerf_fs_nonpaying = nerf_fs_nonpaying
|
|
377
385
|
if org_admin_alerts is not None:
|
|
378
386
|
self.org_admin_alerts = org_admin_alerts
|
|
379
387
|
if org_level_member_permissions is not None:
|
|
@@ -424,6 +432,8 @@ class V1UserFeatures(object):
|
|
|
424
432
|
self.stop_ide_container_on_shutdown = stop_ide_container_on_shutdown
|
|
425
433
|
if studio_config is not None:
|
|
426
434
|
self.studio_config = studio_config
|
|
435
|
+
if studio_deployment is not None:
|
|
436
|
+
self.studio_deployment = studio_deployment
|
|
427
437
|
if studio_on_stop is not None:
|
|
428
438
|
self.studio_on_stop = studio_on_stop
|
|
429
439
|
if studio_version_visibility is not None:
|
|
@@ -1369,6 +1379,27 @@ class V1UserFeatures(object):
|
|
|
1369
1379
|
|
|
1370
1380
|
self._multiple_studio_versions = multiple_studio_versions
|
|
1371
1381
|
|
|
1382
|
+
@property
|
|
1383
|
+
def nerf_fs_nonpaying(self) -> 'bool':
|
|
1384
|
+
"""Gets the nerf_fs_nonpaying of this V1UserFeatures. # noqa: E501
|
|
1385
|
+
|
|
1386
|
+
|
|
1387
|
+
:return: The nerf_fs_nonpaying of this V1UserFeatures. # noqa: E501
|
|
1388
|
+
:rtype: bool
|
|
1389
|
+
"""
|
|
1390
|
+
return self._nerf_fs_nonpaying
|
|
1391
|
+
|
|
1392
|
+
@nerf_fs_nonpaying.setter
|
|
1393
|
+
def nerf_fs_nonpaying(self, nerf_fs_nonpaying: 'bool'):
|
|
1394
|
+
"""Sets the nerf_fs_nonpaying of this V1UserFeatures.
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
:param nerf_fs_nonpaying: The nerf_fs_nonpaying of this V1UserFeatures. # noqa: E501
|
|
1398
|
+
:type: bool
|
|
1399
|
+
"""
|
|
1400
|
+
|
|
1401
|
+
self._nerf_fs_nonpaying = nerf_fs_nonpaying
|
|
1402
|
+
|
|
1372
1403
|
@property
|
|
1373
1404
|
def org_admin_alerts(self) -> 'bool':
|
|
1374
1405
|
"""Gets the org_admin_alerts of this V1UserFeatures. # noqa: E501
|
|
@@ -1894,6 +1925,27 @@ class V1UserFeatures(object):
|
|
|
1894
1925
|
|
|
1895
1926
|
self._studio_config = studio_config
|
|
1896
1927
|
|
|
1928
|
+
@property
|
|
1929
|
+
def studio_deployment(self) -> 'bool':
|
|
1930
|
+
"""Gets the studio_deployment of this V1UserFeatures. # noqa: E501
|
|
1931
|
+
|
|
1932
|
+
|
|
1933
|
+
:return: The studio_deployment of this V1UserFeatures. # noqa: E501
|
|
1934
|
+
:rtype: bool
|
|
1935
|
+
"""
|
|
1936
|
+
return self._studio_deployment
|
|
1937
|
+
|
|
1938
|
+
@studio_deployment.setter
|
|
1939
|
+
def studio_deployment(self, studio_deployment: 'bool'):
|
|
1940
|
+
"""Sets the studio_deployment of this V1UserFeatures.
|
|
1941
|
+
|
|
1942
|
+
|
|
1943
|
+
:param studio_deployment: The studio_deployment of this V1UserFeatures. # noqa: E501
|
|
1944
|
+
:type: bool
|
|
1945
|
+
"""
|
|
1946
|
+
|
|
1947
|
+
self._studio_deployment = studio_deployment
|
|
1948
|
+
|
|
1897
1949
|
@property
|
|
1898
1950
|
def studio_on_stop(self) -> 'bool':
|
|
1899
1951
|
"""Gets the studio_on_stop of this V1UserFeatures. # noqa: E501
|
lightning_sdk/llm/llm.py
CHANGED
|
@@ -1,41 +1,64 @@
|
|
|
1
|
-
from typing import Dict, List, Optional, Set, Tuple
|
|
1
|
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
2
2
|
|
|
3
|
+
from lightning_sdk.api import UserApi
|
|
3
4
|
from lightning_sdk.api.llm_api import LLMApi
|
|
5
|
+
from lightning_sdk.lightning_cloud.login import Auth
|
|
4
6
|
from lightning_sdk.lightning_cloud.openapi import V1Assistant
|
|
7
|
+
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
|
|
8
|
+
from lightning_sdk.organization import Organization
|
|
9
|
+
from lightning_sdk.user import User
|
|
10
|
+
from lightning_sdk.utils.resolve import _resolve_org, _resolve_user
|
|
5
11
|
|
|
6
12
|
|
|
7
13
|
class LLM:
|
|
8
|
-
def __init__(
|
|
14
|
+
def __init__(
|
|
15
|
+
self,
|
|
16
|
+
name: str,
|
|
17
|
+
user: Union[str, "User", None] = None,
|
|
18
|
+
org: Union[str, "Organization", None] = None,
|
|
19
|
+
) -> None:
|
|
20
|
+
self._auth = Auth()
|
|
21
|
+
self._user = None
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
self._auth.authenticate()
|
|
25
|
+
self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
|
|
26
|
+
except ConnectionError as e:
|
|
27
|
+
raise e
|
|
28
|
+
|
|
29
|
+
self._name = name
|
|
30
|
+
try:
|
|
31
|
+
self._user = _resolve_user(self._user or user)
|
|
32
|
+
except ValueError:
|
|
33
|
+
self._user = None
|
|
34
|
+
|
|
9
35
|
self._name = name
|
|
10
36
|
self._org, self._model_name = self._parse_model_name(name)
|
|
37
|
+
try:
|
|
38
|
+
# check if it is a org model
|
|
39
|
+
self._org = _resolve_org(self._org or org)
|
|
40
|
+
except ApiException:
|
|
41
|
+
self._org = None
|
|
42
|
+
|
|
11
43
|
self._llm_api = LLMApi()
|
|
12
|
-
self.
|
|
13
|
-
self.
|
|
14
|
-
self.
|
|
44
|
+
self._public_models = self._build_model_lookup(self._get_public_models())
|
|
45
|
+
self._org_models = self._build_model_lookup(self._get_org_models())
|
|
46
|
+
self._user_models = self._build_model_lookup(self._get_user_models())
|
|
15
47
|
self._model = self._get_model()
|
|
48
|
+
self._conversations = {}
|
|
16
49
|
|
|
17
50
|
def _parse_model_name(self, name: str) -> Tuple[str, str]:
|
|
18
51
|
parts = name.split("/")
|
|
19
|
-
if len(parts)
|
|
20
|
-
|
|
21
|
-
|
|
52
|
+
if len(parts) == 1:
|
|
53
|
+
# a user model or a org model
|
|
54
|
+
return None, parts[0]
|
|
55
|
+
if len(parts) == 2:
|
|
56
|
+
return parts[0], parts[1]
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
|
|
59
|
+
)
|
|
22
60
|
|
|
23
61
|
def _build_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
|
|
24
|
-
return {endpoint.id: {model.name for model in endpoint.models_metadata} for endpoint in endpoints}
|
|
25
|
-
|
|
26
|
-
def _model_exists(self) -> bool:
|
|
27
|
-
if self._org not in self._models:
|
|
28
|
-
raise ValueError(
|
|
29
|
-
f"Model provider {self._org} not found. Available models providers: {list(self._models.keys())}"
|
|
30
|
-
)
|
|
31
|
-
|
|
32
|
-
if self._model_name not in self._models[self._org]:
|
|
33
|
-
raise ValueError(
|
|
34
|
-
f"Model {self._model_name} not found. Available models by {self._org}: {self._models[self._org]}"
|
|
35
|
-
)
|
|
36
|
-
return True
|
|
37
|
-
|
|
38
|
-
def _build_public_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
|
|
39
62
|
result = {}
|
|
40
63
|
for endpoint in endpoints:
|
|
41
64
|
result.setdefault(endpoint.model, []).append(endpoint)
|
|
@@ -44,13 +67,94 @@ class LLM:
|
|
|
44
67
|
def _get_public_models(self) -> List[str]:
|
|
45
68
|
return self._llm_api.get_public_models()
|
|
46
69
|
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
70
|
+
def _get_org_models(self) -> List[str]:
|
|
71
|
+
return self._llm_api.get_org_models(self._org.id) if self._org else []
|
|
72
|
+
|
|
73
|
+
def _get_user_models(self) -> List[str]:
|
|
74
|
+
return self._llm_api.get_user_models(self._user.id) if self._user else []
|
|
75
|
+
|
|
76
|
+
def _get_model(self) -> V1Assistant:
|
|
51
77
|
# TODO how to handle multiple models with same model type? For now, just use the first one
|
|
52
|
-
|
|
78
|
+
if self._model_name in self._public_models:
|
|
79
|
+
return self._public_models.get(self._model_name)[0]
|
|
80
|
+
if self._model_name in self._org_models:
|
|
81
|
+
return self._org_models.get(self._model_name)[0]
|
|
82
|
+
if self._model_name in self._user_models:
|
|
83
|
+
return self._user_models.get(self._model_name)[0]
|
|
84
|
+
|
|
85
|
+
available_models = []
|
|
86
|
+
if self._public_models:
|
|
87
|
+
available_models.append(f"Public Models: {', '.join(self._public_models.keys())}")
|
|
53
88
|
|
|
54
|
-
|
|
55
|
-
|
|
89
|
+
if self._org and self._org_models:
|
|
90
|
+
available_models.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
|
|
91
|
+
|
|
92
|
+
if self._user and self._user_models:
|
|
93
|
+
available_models.append(f"User ({self._user.name}) Models: {', '.join(self._user_models.keys())}")
|
|
94
|
+
|
|
95
|
+
available_models_str = "\n".join(available_models)
|
|
96
|
+
raise ValueError(f"Model '{self._model_name}' not found. \nAvailable models: \n{available_models_str}")
|
|
97
|
+
|
|
98
|
+
def _get_conversations(self) -> Dict[str, str]:
|
|
99
|
+
# TODO: after updating backend, this will fetch conversations from backend
|
|
100
|
+
# conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
|
|
101
|
+
return self._conversations
|
|
102
|
+
|
|
103
|
+
def _fetch_conversations(self) -> None:
|
|
104
|
+
self._conversations = self._get_conversations()
|
|
105
|
+
|
|
106
|
+
def chat(
|
|
107
|
+
self,
|
|
108
|
+
prompt: str,
|
|
109
|
+
system_prompt: Optional[str] = None,
|
|
110
|
+
max_completion_tokens: Optional[int] = 500,
|
|
111
|
+
conversation: Optional[str] = None,
|
|
112
|
+
) -> str:
|
|
113
|
+
if conversation and conversation not in self._conversations:
|
|
114
|
+
self._fetch_conversations()
|
|
115
|
+
|
|
116
|
+
conversation_id = self._conversations.get(conversation) if conversation else None
|
|
117
|
+
output = self._llm_api.start_conversation(
|
|
118
|
+
prompt=prompt,
|
|
119
|
+
system_prompt=system_prompt,
|
|
120
|
+
max_completion_tokens=max_completion_tokens,
|
|
121
|
+
assistant_id=self._model.id,
|
|
122
|
+
conversation_id=conversation_id,
|
|
123
|
+
)
|
|
124
|
+
if conversation and not conversation_id:
|
|
125
|
+
self._conversations[conversation] = output.conversation_id
|
|
56
126
|
return output.choices[0].delta.content
|
|
127
|
+
|
|
128
|
+
def list_conversations(self) -> List[Dict]:
|
|
129
|
+
self._fetch_conversations()
|
|
130
|
+
return list(self._conversations.keys())
|
|
131
|
+
|
|
132
|
+
def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
|
|
133
|
+
return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
|
|
134
|
+
|
|
135
|
+
def get_history(self, conversation: str) -> Optional[List[Dict]]:
|
|
136
|
+
# TODO: after updating backend, this will fetch conversation from backend
|
|
137
|
+
if conversation not in self._conversations:
|
|
138
|
+
self._fetch_conversations()
|
|
139
|
+
|
|
140
|
+
if conversation not in self._conversations:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
f"Conversation '{conversation}' not found. \nAvailable conversations: {self._conversations.keys()}"
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
messages = self._get_conversation_messages(self._conversations[conversation])
|
|
146
|
+
history = []
|
|
147
|
+
for message in messages:
|
|
148
|
+
if message.author.role == "user":
|
|
149
|
+
history.append({"role": "user", "content": message.content[0].parts[0]})
|
|
150
|
+
elif message.author.role == "assistant":
|
|
151
|
+
history.append({"role": "assistant", "content": message.content[0].parts[0]})
|
|
152
|
+
return history
|
|
153
|
+
|
|
154
|
+
def reset_conversation(self, conversation: str) -> None:
|
|
155
|
+
if conversation in self._conversations:
|
|
156
|
+
self._llm_api.reset_conversation(
|
|
157
|
+
assistant_id=self._model.id,
|
|
158
|
+
conversation_id=self._conversations[conversation],
|
|
159
|
+
)
|
|
160
|
+
del self._conversations[conversation]
|
lightning_sdk/plugin.py
CHANGED
|
@@ -395,6 +395,25 @@ class SlurmJobsPlugin(_Plugin):
|
|
|
395
395
|
return resp
|
|
396
396
|
|
|
397
397
|
|
|
398
|
+
class CustomPortPlugin(_Plugin):
|
|
399
|
+
"""Plugin handling the port of a given service."""
|
|
400
|
+
|
|
401
|
+
_plugin_run_name = "Custom Port"
|
|
402
|
+
_slug_name = "custom-port"
|
|
403
|
+
|
|
404
|
+
def run(self, name: Optional[str] = None, port: int = 8000) -> str:
|
|
405
|
+
"""Starts a new port to the given Studio."""
|
|
406
|
+
if name is None:
|
|
407
|
+
name = _run_name("port")
|
|
408
|
+
|
|
409
|
+
return self._studio._studio_api.start_new_port(
|
|
410
|
+
teamspace_id=self._studio._teamspace.id,
|
|
411
|
+
studio_id=self._studio._studio.id,
|
|
412
|
+
name=name,
|
|
413
|
+
port=port,
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
|
|
398
417
|
@runtime_checkable
|
|
399
418
|
class _RunnablePlugin(Protocol):
|
|
400
419
|
_plugin_run_name: str
|
lightning_sdk/studio.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
import glob
|
|
1
2
|
import os
|
|
2
3
|
import warnings
|
|
3
4
|
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Union
|
|
4
5
|
|
|
6
|
+
from tqdm.auto import tqdm
|
|
7
|
+
|
|
5
8
|
from lightning_sdk.api.studio_api import StudioApi
|
|
6
9
|
from lightning_sdk.api.utils import _machine_to_compute_name
|
|
7
10
|
from lightning_sdk.constants import _LIGHTNING_DEBUG
|
|
@@ -265,6 +268,34 @@ class Studio:
|
|
|
265
268
|
progress_bar=progress_bar,
|
|
266
269
|
)
|
|
267
270
|
|
|
271
|
+
def upload_folder(self, folder_path: str, remote_path: Optional[str] = None, progress_bar: bool = True) -> None:
|
|
272
|
+
"""Uploads a given folder to a remote path on the Studio."""
|
|
273
|
+
if folder_path is None:
|
|
274
|
+
raise ValueError("Cannot upload a folder that is None.")
|
|
275
|
+
folder_path = os.path.normpath(folder_path)
|
|
276
|
+
if os.path.isfile(folder_path):
|
|
277
|
+
raise NotADirectoryError(f"Cannot upload a file as a folder. '{folder_path}' is a file.")
|
|
278
|
+
if not os.path.exists(folder_path):
|
|
279
|
+
raise NotADirectoryError(f"Cannot upload a folder that does not exist. '{folder_path}' is not a directory.")
|
|
280
|
+
all_files = []
|
|
281
|
+
for fp in glob.glob(os.path.join(folder_path, "**"), recursive=True):
|
|
282
|
+
if not os.path.isfile(fp):
|
|
283
|
+
continue
|
|
284
|
+
rel_path = os.path.relpath(fp, folder_path)
|
|
285
|
+
remote_file = os.path.join(remote_path, rel_path) if remote_path else rel_path
|
|
286
|
+
all_files.append((fp, remote_file))
|
|
287
|
+
|
|
288
|
+
if progress_bar:
|
|
289
|
+
progress_bar = tqdm(total=len(all_files), desc="Uploading files", unit="file")
|
|
290
|
+
for local_file, remote_path in sorted(all_files, key=lambda p: p[1]):
|
|
291
|
+
if progress_bar:
|
|
292
|
+
progress_bar.set_description(f"Uploading {local_file}")
|
|
293
|
+
self.upload_file(local_file, remote_path=remote_path, progress_bar=False)
|
|
294
|
+
if progress_bar:
|
|
295
|
+
progress_bar.update(1)
|
|
296
|
+
if progress_bar:
|
|
297
|
+
progress_bar.close()
|
|
298
|
+
|
|
268
299
|
def download_file(self, remote_path: str, file_path: Optional[str] = None) -> None:
|
|
269
300
|
"""Downloads a file from the Studio to a given target path."""
|
|
270
301
|
if file_path is None:
|
|
@@ -445,6 +476,7 @@ class Studio:
|
|
|
445
476
|
def _add_plugin(self, plugin_name: str) -> None:
|
|
446
477
|
"""Adds the just installed plugin to the internal list of plugins."""
|
|
447
478
|
from lightning_sdk.plugin import (
|
|
479
|
+
CustomPortPlugin,
|
|
448
480
|
InferenceServerPlugin,
|
|
449
481
|
JobsPlugin,
|
|
450
482
|
MultiMachineTrainingPlugin,
|
|
@@ -458,6 +490,7 @@ class Studio:
|
|
|
458
490
|
"jobs": JobsPlugin,
|
|
459
491
|
"multi-machine-training": MultiMachineTrainingPlugin,
|
|
460
492
|
"inference-server": InferenceServerPlugin,
|
|
493
|
+
"custom-port": CustomPortPlugin,
|
|
461
494
|
}.get(plugin_name, Plugin)
|
|
462
495
|
|
|
463
496
|
description = self._list_installed_plugins()[plugin_name]
|