lightning-sdk 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/api/base_studio_api.py +79 -0
- lightning_sdk/api/cluster_api.py +83 -1
- lightning_sdk/api/license_api.py +48 -0
- lightning_sdk/api/llm_api.py +73 -12
- lightning_sdk/api/studio_api.py +50 -1
- lightning_sdk/api/teamspace_api.py +127 -1
- lightning_sdk/api/utils.py +4 -0
- lightning_sdk/base_studio.py +83 -0
- lightning_sdk/cli/create.py +21 -1
- lightning_sdk/cli/delete.py +6 -8
- lightning_sdk/cli/deploy/__init__.py +0 -0
- lightning_sdk/cli/deploy/_auth.py +189 -0
- lightning_sdk/cli/deploy/devbox.py +157 -0
- lightning_sdk/cli/{serve.py → deploy/serve.py} +22 -281
- lightning_sdk/cli/download.py +69 -16
- lightning_sdk/cli/entrypoint.py +1 -1
- lightning_sdk/cli/open.py +21 -2
- lightning_sdk/cli/start.py +12 -3
- lightning_sdk/cli/teamspace_menu.py +9 -1
- lightning_sdk/cli/upload.py +2 -5
- lightning_sdk/lightning_cloud/openapi/__init__.py +29 -0
- lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +121 -0
- lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +9 -1
- lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +226 -0
- lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
- lightning_sdk/lightning_cloud/openapi/api/file_system_service_api.py +178 -0
- lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +984 -101
- lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +525 -0
- lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +93 -0
- lightning_sdk/lightning_cloud/openapi/configuration.py +1 -1
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +28 -0
- lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/cloudspaces_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body.py +331 -0
- lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body1.py +305 -0
- lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/endpoints_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/model_id_versions_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/models_id_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +183 -1
- lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body.py +6 -6
- lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_storage_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +107 -3
- lightning_sdk/lightning_cloud/openapi/models/storage_complete_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/update.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/uploads_upload_id_body1.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_aws_direct_v1.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +3 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_source_type.py +103 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudflare_v1.py +66 -66
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_upload.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_complete_upload.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_delete_deployment_alerting_policy_response.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_event.py +487 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy.py +383 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_frequency.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_operation.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_severity.py +106 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_type.py +111 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +4 -4
- lightning_sdk/lightning_cloud/openapi/models/v1_endpoint.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_external_search_user.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_ge_list_deployment_routing_telemetry_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_instance_open_ports_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_deployment_routing_telemetry_content_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_organization_storage_metadata_response.py +331 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_project_balance_response.py +1 -27
- lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_job_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_events_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_policies_response.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_product_licenses_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +43 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_modify_filesystem_volume_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +183 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_pipeline.py +6 -6
- lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_state.py +111 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_presigned_url.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +409 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_product_license_check_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_project.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +43 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +107 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_project_storage.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_r2_data_connection.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_routing_telemetry.py +253 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_secret_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +2 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_sleep_server_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_trigger_filesystem_upgrade_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_upload_project_artifact_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_usage_report.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +347 -113
- lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
- lightning_sdk/lightning_cloud/rest_client.py +4 -0
- lightning_sdk/llm/llm.py +132 -40
- lightning_sdk/services/__init__.py +1 -1
- lightning_sdk/services/license.py +236 -0
- lightning_sdk/studio.py +62 -1
- lightning_sdk/teamspace.py +68 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/METADATA +1 -1
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/RECORD +122 -86
- /lightning_sdk/services/{finetune/__init__.py → finetune_llm.py} +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py
CHANGED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Any, List, Optional
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.lightning_cloud.openapi.models.update import Update as BaseStudioUpdateBody
|
|
4
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_cloud_space_environment_template import (
|
|
5
|
+
V1CloudSpaceEnvironmentTemplate,
|
|
6
|
+
)
|
|
7
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_cloud_space_environment_type import V1CloudSpaceEnvironmentType
|
|
8
|
+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseStudioApi:
    """Wrapper around the cloud-space environment template endpoints of ``LightningClient``.

    A "base studio" in this class is a ``V1CloudSpaceEnvironmentTemplate``.
    """

    def __init__(self) -> None:
        # The client is created with retries switched off.
        self._client = LightningClient(retry=False, max_tries=0)

    def get_base_studio(self, base_studio_id: str, org_id: str) -> V1CloudSpaceEnvironmentTemplate:
        """Fetch one base studio by its ID.

        Args:
            base_studio_id: ID of the base studio to fetch.
            org_id: organization ID forwarded to the client call.

        Raises:
            ValueError: when the client call raises ``ValueError``; re-raised with a
                "does not exist" message and the original error chained as the cause.
        """
        try:
            template = self._client.cloud_space_environment_template_service_get_cloud_space_environment_template(
                base_studio_id, org_id=org_id
            )
        except ValueError as err:
            raise ValueError(f"Base studio {base_studio_id} does not exist") from err
        return template

    def get_all_base_studios(self, org_id: str) -> List[V1CloudSpaceEnvironmentTemplate]:
        """Return the client's listing of base studios for ``org_id`` (passed through unchanged)."""
        listing = self._client.cloud_space_environment_template_service_list_cloud_space_environment_templates(
            org_id=org_id
        )
        return listing

    def update_base_studio(
        self,
        base_studio_id: str,
        org_id: str,
        name: Optional[str] = None,
        allowed_machines: Optional[List[str]] = None,
        default_machine: Optional[str] = None,
        disabled: Optional[bool] = None,
        environment_type: Optional[V1CloudSpaceEnvironmentType] = None,
        machine_image_version: Optional[str] = None,
        setup_script_text: Optional[str] = None,
    ) -> V1CloudSpaceEnvironmentTemplate:
        """Update a base studio, changing only the fields that were explicitly provided.

        The current template is fetched first and its values seed the update body, so any
        argument left as ``None`` keeps its existing value.

        Returns:
            Whatever the client's update call returns.
        """
        existing = self.get_base_studio(base_studio_id, org_id)

        # Seed the body with the values currently stored on the template.
        update_body = BaseStudioUpdateBody(
            org_id=existing.org_id,
            name=existing.name,
            allowed_machines=existing.config.allowed_machines,
            default_machine=existing.config.default_machine,
            environment_type=existing.config.environment_type,
            machine_image_version=existing.config.machine_image_version,
            setup_script_text=existing.config.setup_script_text,
            disabled=existing.disabled,
        )

        # apply_change ignores None, so only caller-supplied values overwrite the seed.
        requested_changes = (
            ("name", name),
            ("allowed_machines", allowed_machines),
            ("default_machine", default_machine),
            ("environment_type", environment_type),
            ("machine_image_version", machine_image_version),
            ("setup_script_text", setup_script_text),
            ("disabled", disabled),
        )
        for field_name, new_value in requested_changes:
            apply_change(update_body, field_name, new_value)

        return self._client.cloud_space_environment_template_service_update_cloud_space_environment_template(
            id=base_studio_id,
            body=update_body,
        )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def apply_change(spec: Any, key: str, value: Any) -> bool:
    """Set ``spec.<key>`` to ``value`` when a value was given and it differs from the current one.

    Args:
        spec: object whose attribute may be updated.
        key: name of the attribute on ``spec``.
        value: new value; ``None`` means "leave unchanged".

    Returns:
        True if the attribute was overwritten, False otherwise.
    """
    if value is None:
        return False

    differs = bool(getattr(spec, key) != value)
    if differs:
        setattr(spec, key, value)
    return differs
|
lightning_sdk/api/cluster_api.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
|
1
|
-
from
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.lightning_cloud.openapi import (
|
|
4
|
+
Externalv1Cluster,
|
|
5
|
+
V1CloudProvider,
|
|
6
|
+
V1ClusterType,
|
|
7
|
+
V1ListClusterAcceleratorsResponse,
|
|
8
|
+
)
|
|
2
9
|
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
3
10
|
|
|
4
11
|
|
|
@@ -20,3 +27,78 @@ class ClusterApi:
|
|
|
20
27
|
if not res:
|
|
21
28
|
raise ValueError(f"Cluster {cluster_id} does not exist")
|
|
22
29
|
return res
|
|
30
|
+
|
|
31
|
+
def list_cluster_accelerators(self, cluster_id: str, org_id: str) -> V1ListClusterAcceleratorsResponse:
    """Lists the accelerators for a given cluster.

    :param cluster_id: ID of the cluster to query
    :param org_id: The owning org of this cluster
    :raises ValueError: if the client returns a falsy response
    """
    accelerators = self._client.cluster_service_list_cluster_accelerators(
        id=cluster_id,
        org_id=org_id,
    )
    if accelerators:
        return accelerators
    raise ValueError(f"Cluster {cluster_id} does not exist")
|
|
45
|
+
|
|
46
|
+
def list_global_clusters(self, project_id: str, org_id: str) -> List[Externalv1Cluster]:
    """Lists the clusters of a project whose spec has cluster type ``GLOBAL``.

    :param project_id: ID of the project whose clusters are listed
    :param org_id: The owning org of this project
    :raises ValueError: if the client returns a falsy response
    """
    listing = self._client.cluster_service_list_clusters(
        project_id=project_id,
        org_id=org_id,
    )
    if not listing:
        raise ValueError(f"Project {project_id} does not exist")
    return [cluster for cluster in listing.clusters if cluster.spec.cluster_type == V1ClusterType.GLOBAL]
|
|
60
|
+
|
|
61
|
+
def get_cluster_provider_mapping(self, project_id: str, org_id: str) -> Dict[V1CloudProvider, str]:
    """Map each cloud provider to the ID of a global cluster of the project.

    When several global clusters resolve to the same provider, the one listed last wins.
    """
    mapping = {}
    for cluster in self.list_global_clusters(project_id=project_id, org_id=org_id):
        mapping[self._get_cluster_provider(cluster)] = cluster.id
    return mapping
|
|
68
|
+
|
|
69
|
+
def _get_cluster_provider(self, cluster: Optional[Externalv1Cluster]) -> V1CloudProvider:
    """Work out which cloud provider a cluster belongs to.

    Args:
        cluster: the cluster to inspect; may be ``None``.

    Returns:
        V1CloudProvider: the spec's ``driver`` when it is LIGHTNING or DGX; otherwise the
        provider of the first populated provider-specific spec field; AWS when nothing matches
        (including a missing cluster or a missing spec).
    """
    spec = cluster.spec if cluster else None
    if not spec:
        return V1CloudProvider.AWS

    # An explicit LIGHTNING/DGX driver takes precedence over the provider-specific fields.
    driver = spec.driver
    if driver and driver in (V1CloudProvider.LIGHTNING, V1CloudProvider.DGX):
        return driver

    # Checked in this order; the first truthy spec field decides.
    provider_by_spec_field = (
        ("aws_v1", V1CloudProvider.AWS),
        ("google_cloud_v1", V1CloudProvider.GCP),
        ("lambda_labs_v1", V1CloudProvider.LAMBDA_LABS),
        ("vultr_v1", V1CloudProvider.VULTR),
        ("slurm_v1", V1CloudProvider.SLURM),
        ("voltage_park_v1", V1CloudProvider.VOLTAGE_PARK),
        ("nebius_v1", V1CloudProvider.NEBIUS),
    )
    for field_name, provider in provider_by_spec_field:
        if getattr(spec, field_name):
            return provider

    return V1CloudProvider.AWS
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LicenseApi:
    """Wrapper around the product-license endpoints of ``LightningClient``."""

    def __init__(self) -> None:
        # The client is created with retries switched off.
        self._client = LightningClient(retry=False, max_tries=0)

    def valid_license(
        self,
        license_key: str,
        product_name: str,
        product_version: Optional[str] = None,
        product_type: str = "package",
    ) -> bool:
        """Ask the service whether a license key is valid for a product.

        Args:
            license_key: The license key to check.
            product_name: The name of the product.
            product_version: The version of the product.
            product_type: The type of the product. Default is "package".

        Returns:
            The ``valid`` field of the service response.

        Raises:
            ConnectionError: if the HTTP status code of the response is not 200.
        """
        request = {
            "license_key": license_key,
            "product_name": product_name,
            "product_version": product_version,
            "product_type": product_type,
        }
        # The *_with_http_info variant returns (data, status code, headers).
        response, code, _headers = self._client.product_license_service_validate_product_license_with_http_info(
            **request
        )
        if code == 200:
            return response.valid
        raise ConnectionError(f"Failed to validate license key: {code} - {response}")

    def list_user_licenses(self, user_id: str) -> list:
        """Return the ``licenses`` field of the service's listing for a user.

        Args:
            user_id: The ID of the user.
        """
        listing = self._client.product_license_service_list_user_licenses(user_id=user_id)
        return listing.licenses
|
lightning_sdk/api/llm_api.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
|
-
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Dict, Generator, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pip._vendor.urllib3 import HTTPResponse
|
|
2
6
|
|
|
3
7
|
from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
|
|
8
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_response_choice import V1ResponseChoice
|
|
9
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_response_choice_delta import V1ResponseChoiceDelta
|
|
4
10
|
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
5
11
|
|
|
6
12
|
|
|
@@ -20,30 +26,85 @@ class LLMApi:
|
|
|
20
26
|
result = self._client.assistants_service_list_assistants(user_id=user_id)
|
|
21
27
|
return result.assistants
|
|
22
28
|
|
|
29
|
+
def _stream_chat_response(self, result: HTTPResponse) -> Generator[V1ConversationResponseChunk, None, None]:
    """Yield conversation chunks parsed from a streaming, newline-delimited JSON response.

    ``result.stream()`` yields raw byte chunks whose boundaries need not coincide with line
    boundaries, so bytes are buffered and only complete lines are decoded and parsed. This
    keeps a JSON document (or a multi-byte UTF-8 character) that straddles two chunks intact.

    Args:
        result: the un-preloaded HTTP response to read from.

    Yields:
        One ``V1ConversationResponseChunk`` per JSON line, built from the line's ``result`` object.
        Lines that are not valid JSON are reported with ``print`` and skipped; blank lines are ignored.
    """

    def _complete_lines() -> Generator[bytes, None, None]:
        # Re-assemble newline-terminated lines from arbitrarily sized chunks.
        pending = b""
        for raw_chunk in result.stream():
            pending += raw_chunk
            *finished, pending = pending.split(b"\n")
            yield from finished
        if pending:
            # The final line may arrive without a trailing newline.
            yield pending

    for raw_line in _complete_lines():
        decoded_line = raw_line.decode("utf-8").strip()
        if not decoded_line:
            continue

        try:
            payload = json.loads(decoded_line)
        except json.JSONDecodeError:
            # Best effort: report the bad line and keep consuming the stream.
            print("Error decoding JSON:", decoded_line)
            continue

        result_data = payload.get("result", {})
        choices = [
            V1ResponseChoice(
                delta=V1ResponseChoiceDelta(**choice.get("delta", {})),
                finish_reason=choice.get("finishReason"),
                index=choice.get("index"),
            )
            for choice in result_data.get("choices", [])
        ]

        yield V1ConversationResponseChunk(
            choices=choices,
            conversation_id=result_data.get("conversationId"),
            executable=result_data.get("executable"),
            id=result_data.get("id"),
            throughput=result_data.get("throughput"),
        )
|
|
58
|
+
|
|
59
|
+
def _encode_image_bytes_to_data_url(self, image: str, mime_type: str = "image/jpeg") -> str:
|
|
60
|
+
with open(image, "rb") as image_file:
|
|
61
|
+
b64 = base64.b64encode(image_file.read()).decode("utf-8")
|
|
62
|
+
return f"data:{mime_type};base64,{b64}"
|
|
63
|
+
|
|
23
64
|
def start_conversation(
    self,
    prompt: str,
    system_prompt: Optional[str],
    max_completion_tokens: int,
    assistant_id: str,
    images: Optional[List[str]] = None,
    conversation_id: Optional[str] = None,
    billing_project_id: Optional[str] = None,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    stream: bool = False,
) -> Union[V1ConversationResponseChunk, Generator[V1ConversationResponseChunk, None, None]]:
    """Send a user message to an assistant and return its reply.

    Args:
        prompt: text of the user message.
        system_prompt: accepted for interface compatibility.
            NOTE(review): not referenced anywhere in this body - confirm that is intended.
        max_completion_tokens: forwarded in the request body.
        assistant_id: assistant the conversation is started with.
        images: optional image references. Entries starting with "http" are sent as-is;
            anything else is treated as a local file path and inlined as a data URL.
        conversation_id: forwarded in the request body.
        billing_project_id: forwarded in the request body.
        name: forwarded in the request body.
        metadata: forwarded in the request body; an empty dict is sent when not given.
        stream: when True the response is not preloaded and a generator of chunks is returned.

    Returns:
        ``result.result`` of the client response when ``stream`` is False, otherwise a
        generator produced by ``_stream_chat_response``.
    """
    # The text part always comes first; image parts follow in the order given.
    message_content = [{"contentType": "text", "parts": [prompt]}]
    for image in images or []:
        url = image if image.startswith("http") else self._encode_image_bytes_to_data_url(image)
        message_content.append({"contentType": "image", "parts": [url]})

    body = {
        "message": {
            "author": {"role": "user"},
            "content": message_content,
        },
        "max_completion_tokens": max_completion_tokens,
        "conversation_id": conversation_id,
        "billing_project_id": billing_project_id,
        "name": name,
        "stream": stream,
        "metadata": metadata or {},
    }

    # Streaming needs the raw response object, so preloading is disabled in that case.
    result = self._client.assistants_service_start_conversation(body, assistant_id, _preload_content=not stream)
    if stream:
        return self._stream_chat_response(result)
    return result.result
|
|
47
108
|
|
|
48
109
|
def list_conversations(self, assistant_id: str) -> List[str]:
|
|
49
110
|
result = self._client.assistants_service_list_conversations(assistant_id)
|
lightning_sdk/api/studio_api.py
CHANGED
|
@@ -5,7 +5,7 @@ import time
|
|
|
5
5
|
import warnings
|
|
6
6
|
import zipfile
|
|
7
7
|
from threading import Event, Thread
|
|
8
|
-
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
|
8
|
+
from typing import Any, Dict, Generator, Mapping, Optional, Tuple, Union
|
|
9
9
|
|
|
10
10
|
import backoff
|
|
11
11
|
import requests
|
|
@@ -35,6 +35,7 @@ from lightning_sdk.lightning_cloud.openapi import (
|
|
|
35
35
|
V1CloudSpace,
|
|
36
36
|
V1CloudSpaceInstanceConfig,
|
|
37
37
|
V1CloudSpaceSeedFile,
|
|
38
|
+
V1CloudSpaceSourceType,
|
|
38
39
|
V1CloudSpaceState,
|
|
39
40
|
V1EndpointType,
|
|
40
41
|
V1GetCloudSpaceInstanceStatusResponse,
|
|
@@ -110,6 +111,7 @@ class StudioApi:
|
|
|
110
111
|
name: str,
|
|
111
112
|
teamspace_id: str,
|
|
112
113
|
cloud_account: Optional[str] = None,
|
|
114
|
+
source: Optional[V1CloudSpaceSourceType] = None,
|
|
113
115
|
) -> V1CloudSpace:
|
|
114
116
|
"""Create a Studio with a given name in a given Teamspace on a possibly given cloud_account."""
|
|
115
117
|
body = ProjectIdCloudspacesBody(
|
|
@@ -117,6 +119,7 @@ class StudioApi:
|
|
|
117
119
|
name=name,
|
|
118
120
|
display_name=name,
|
|
119
121
|
seed_files=[V1CloudSpaceSeedFile(path="main.py", contents="print('Hello, Lightning World!')\n")],
|
|
122
|
+
source=source,
|
|
120
123
|
)
|
|
121
124
|
studio = self._client.cloud_space_service_create_cloud_space(body, teamspace_id)
|
|
122
125
|
|
|
@@ -285,6 +288,52 @@ class StudioApi:
|
|
|
285
288
|
for response in responses:
|
|
286
289
|
yield response.result
|
|
287
290
|
|
|
291
|
+
def run_studio_commands_and_yield(
    self, studio_id: str, teamspace_id: str, *commands: str, timeout: float, check_interval: float
) -> Generator[Tuple[str, int], None, None]:
    """Run commands detached in a Studio and yield ``(output, exit_code)`` while polling their status.

    The commands are joined with "; " and submitted as one detached command. The status is then
    polled in rounds, sleeping ``check_interval`` seconds between rounds.

    Args:
        studio_id: Studio to run the commands in.
        teamspace_id: teamspace (project) the Studio belongs to.
        commands: the commands to run.
        timeout: stop yielding once this many seconds have passed since submission
            (checked each time a status response is received).
        check_interval: seconds to sleep between polling rounds.

    Raises:
        RuntimeError: if submission fails, the session name is empty, the reported exit code
            changes between responses, or the command reports a non-zero exit code.
    """
    submission = self._client.cloud_space_service_execute_command_in_cloud_space(
        IdExecuteBody1("; ".join(commands), detached=True),
        project_id=teamspace_id,
        id=studio_id,
    )

    if not submission:
        raise RuntimeError("Unable to submit command")

    if submission.session_name == "":
        raise RuntimeError("The session name should be defined.")

    started_at = time.time()
    exit_code = None
    while True:
        status_stream = self._get_detached_command_status(
            studio_id=studio_id,
            teamspace_id=teamspace_id,
            session_id=submission.session_name,
        )
        for status in status_stream:
            if time.time() - started_at >= timeout:
                return

            # An exit code of -1 ends this polling round; the next round starts after the sleep below.
            if status.exit_code == -1:
                break

            # The first exit code seen is remembered; later responses must agree with it.
            if exit_code is None:
                exit_code = status.exit_code
            elif exit_code != status.exit_code:
                raise RuntimeError("Cannot determine exit code")

            if status.exit_code is not None and status.exit_code != 0:
                raise RuntimeError(f"Command failed with exit code {status.exit_code}. Output: {status.output}")

            yield status.output, exit_code
        time.sleep(check_interval)
|
|
336
|
+
|
|
288
337
|
def run_studio_commands(self, studio_id: str, teamspace_id: str, *commands: str) -> Tuple[str, int]:
|
|
289
338
|
"""Run given commands in a given Studio."""
|
|
290
339
|
response_submit = self._client.cloud_space_service_execute_command_in_cloud_space(
|
|
@@ -1,10 +1,20 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import zipfile
|
|
2
4
|
from pathlib import Path
|
|
3
5
|
from typing import Dict, List, Optional, Tuple
|
|
4
6
|
|
|
7
|
+
import requests
|
|
5
8
|
from tqdm.auto import tqdm
|
|
6
9
|
|
|
7
|
-
from lightning_sdk.api.utils import
|
|
10
|
+
from lightning_sdk.api.utils import (
|
|
11
|
+
_download_model_files,
|
|
12
|
+
_DummyBody,
|
|
13
|
+
_FileUploader,
|
|
14
|
+
_get_model_version,
|
|
15
|
+
_ModelFileUploader,
|
|
16
|
+
_resolve_teamspace_remote_path,
|
|
17
|
+
)
|
|
8
18
|
from lightning_sdk.lightning_cloud.login import Auth
|
|
9
19
|
from lightning_sdk.lightning_cloud.openapi import (
|
|
10
20
|
Externalv1LightningappInstance,
|
|
@@ -17,6 +27,7 @@ from lightning_sdk.lightning_cloud.openapi import (
|
|
|
17
27
|
V1ClusterAccelerator,
|
|
18
28
|
V1Endpoint,
|
|
19
29
|
V1Job,
|
|
30
|
+
V1LoginRequest,
|
|
20
31
|
V1Model,
|
|
21
32
|
V1ModelVersionArchive,
|
|
22
33
|
V1MultiMachineJob,
|
|
@@ -331,3 +342,118 @@ class TeamspaceApi:
|
|
|
331
342
|
model_id = self.get_model(teamspace_id=teamspace_id, model_name=model_name).id
|
|
332
343
|
response = self.models_api.models_store_list_model_versions(project_id=teamspace_id, model_id=model_id)
|
|
333
344
|
return response.versions
|
|
345
|
+
|
|
346
|
+
def upload_file(
    self,
    teamspace_id: str,
    cloud_account: str,
    file_path: str,
    remote_path: str,
    progress_bar: bool,
) -> None:
    """Upload a local file to a remote path in the Teamspace drive.

    ``remote_path`` is first mapped through ``_resolve_teamspace_remote_path``; the transfer
    itself is delegated to ``_FileUploader``.
    """
    uploader = _FileUploader(
        client=self._client,
        teamspace_id=teamspace_id,
        cloud_account=cloud_account,
        file_path=file_path,
        remote_path=_resolve_teamspace_remote_path(remote_path),
        progress_bar=progress_bar,
    )
    # The uploader instance is callable; calling it performs the upload.
    uploader()
|
|
363
|
+
|
|
364
|
+
def download_file(
    self,
    path: str,
    target_path: str,
    teamspace_id: str,
    cloud_account: str,
    progress_bar: bool = True,
) -> None:
    """Download a file from the Teamspace drive to a local target location.

    Args:
        path: remote path of the file; mapped through ``_resolve_teamspace_remote_path``.
        target_path: local file path to write to; its parent directory is created if needed.
        teamspace_id: teamspace (project) that owns the file.
        cloud_account: sent as the ``clusterId`` query parameter.
        progress_bar: whether to display a tqdm progress bar.

    Raises:
        requests.HTTPError: if the download endpoint answers with an error status, so an
            error body is never written out as if it were the requested file.
    """
    # TODO: Update this endpoint to permit basic auth
    auth = Auth()
    auth.authenticate()
    token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token

    # The login token is passed as a query parameter (see TODO above).
    query_params = {
        "clusterId": cloud_account,
        "key": _resolve_teamspace_remote_path(path),
        "token": token,
    }

    # Context managers guarantee the connection and the progress bar are released on any exit path.
    with requests.get(
        f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
        params=query_params,
        stream=True,
    ) as r:
        r.raise_for_status()

        # content-length may be absent (e.g. chunked transfer); tqdm accepts total=None.
        content_length = r.headers.get("content-length")
        total_length = int(content_length) if content_length is not None else None

        target_dir = os.path.split(target_path)[0]
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with tqdm(
            desc=f"Downloading {os.path.split(path)[1]}",
            total=total_length,
            unit="B",
            unit_scale=True,
            unit_divisor=1000,
            disable=not progress_bar,
        ) as pbar, open(target_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=4096 * 8):
                f.write(chunk)
                pbar.update(len(chunk))
|
|
411
|
+
|
|
412
|
+
def download_folder(
    self,
    path: str,
    target_path: str,
    teamspace_id: str,
    cloud_account: str,
    progress_bar: bool = True,
) -> None:
    """Download a folder from the Teamspace drive and unpack it into a local directory.

    The endpoint response is spooled to a temporary file and then extracted as a zip archive.

    Args:
        path: remote folder path; mapped through ``_resolve_teamspace_remote_path`` and sent
            as the ``prefix`` query parameter.
        target_path: local directory to extract into; created if needed.
        teamspace_id: teamspace (project) that owns the folder.
        cloud_account: sent as the ``clusterId`` query parameter.
        progress_bar: whether to display a tqdm progress bar.

    Raises:
        requests.HTTPError: if the download endpoint answers with an error status, instead of
            handing an error body to ``zipfile``.
    """
    # TODO: Update this endpoint to permit basic auth
    auth = Auth()
    auth.authenticate()
    token = self._client.auth_service_login(V1LoginRequest(auth.api_key)).token

    # The login token is passed as a query parameter (see TODO above).
    query_params = {
        "clusterId": cloud_account,
        "prefix": _resolve_teamspace_remote_path(path),
        "token": token,
    }

    # Context managers guarantee the connection and the progress bar are released on any exit path.
    with requests.get(
        f"{self._client.api_client.configuration.host}/v1/projects/{teamspace_id}/artifacts/download",
        params=query_params,
        stream=True,
    ) as r:
        r.raise_for_status()

        if target_path:
            os.makedirs(target_path, exist_ok=True)

        with tempfile.TemporaryFile() as f:
            # No total is given: the size of the archive is not known up front.
            with tqdm(
                desc=f"Downloading {os.path.split(path)[1]}",
                unit="B",
                unit_scale=True,
                unit_divisor=1000,
                disable=not progress_bar,
            ) as pbar:
                for chunk in r.iter_content(chunk_size=4096 * 8):
                    f.write(chunk)
                    pbar.update(len(chunk))

            # The archive must be opened while the temporary file is still alive.
            with zipfile.ZipFile(f) as z:
                z.extractall(target_path)
|
lightning_sdk/api/utils.py
CHANGED
|
@@ -355,6 +355,10 @@ def _sanitize_studio_remote_path(path: str, studio_id: str) -> str:
|
|
|
355
355
|
return f"/cloudspaces/{studio_id}/code/content/{path.replace('/teamspace/studios/this_studio/', '')}"
|
|
356
356
|
|
|
357
357
|
|
|
358
|
+
def _resolve_teamspace_remote_path(path: str) -> str:
|
|
359
|
+
return f"/Uploads/{path.replace('/teamspace/studios/this_studio/', '')}"
|
|
360
|
+
|
|
361
|
+
|
|
358
362
|
# Set to 10 * _BYTES_PER_MB; presumably the byte size of one download request chunk -
# TODO confirm against the code that consumes this constant.
_DOWNLOAD_REQUEST_CHUNK_SIZE = 10 * _BYTES_PER_MB
|
|
359
363
|
# Set to 100 * _BYTES_PER_KB; presumably a lower bound for a download chunk size -
# TODO confirm against the code that consumes this constant.
_DOWNLOAD_MIN_CHUNK_SIZE = 100 * _BYTES_PER_KB
|
|
360
364
|
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.api.base_studio_api import BaseStudioApi
|
|
4
|
+
from lightning_sdk.api.user_api import UserApi
|
|
5
|
+
from lightning_sdk.lightning_cloud import login
|
|
6
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_cloud_space_environment_template import (
|
|
7
|
+
V1CloudSpaceEnvironmentTemplate,
|
|
8
|
+
)
|
|
9
|
+
from lightning_sdk.lightning_cloud.openapi.models.v1_cloud_space_environment_type import V1CloudSpaceEnvironmentType
|
|
10
|
+
from lightning_sdk.organization import Organization
|
|
11
|
+
from lightning_sdk.user import User
|
|
12
|
+
from lightning_sdk.utils.resolve import _resolve_org, _resolve_user
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseStudio:
    """Handle for the base studio templates of an organization.

    An instance authenticates on construction, resolves the user and the
    organization, and optionally loads one base studio by name so that it can
    be modified with :meth:`update`.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        org: Optional[Union[str, Organization]] = None,
        user: Optional[Union[str, User]] = None,
    ) -> None:
        """Initializes the BaseStudio instance with organization and user information.

        Args:
            name (Optional[str]): Name of an existing base studio to load from the organization.
                If not provided, no base studio is loaded and :meth:`update` cannot be used.
            org (Optional[Union[str, Organization]]): The organization for the base studio. It is
                resolved through ``_resolve_org``.
            user (Optional[Union[str, User]]): The user for the base studio. If not provided, the
                authenticated user is looked up and used instead.

        Raises:
            ConnectionError: If there is an issue with the authentication process.
            ValueError: If ``name`` is given but no base studio with that name exists in the organization.
        """
        self._auth = login.Auth()
        self._user = None
        # Stays None when no name is given, so update() can report a clear error
        # instead of failing on a missing attribute.
        self._base_studio = None

        # A ConnectionError raised here propagates to the caller unchanged.
        self._auth.authenticate()
        if user is None:
            self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)

        self._user = _resolve_user(self._user or user)
        self._org = _resolve_org(org)

        self._base_studio_api = BaseStudioApi()

        if name is not None:
            base_studio = self._base_studio_api.get_base_studio(name, self._org.id)

            if base_studio is None:
                raise ValueError(f"Base studio with name {name} does not exist in organization {self._org.name}")
            self._base_studio = base_studio

    def update(
        self,
        name: Optional[str] = None,
        allowed_machines: Optional[List[str]] = None,
        default_machine: Optional[str] = None,
        disabled: Optional[bool] = None,
        environment_type: Optional[V1CloudSpaceEnvironmentType] = None,
        machine_image_version: Optional[str] = None,
        setup_script_text: Optional[str] = None,
    ) -> V1CloudSpaceEnvironmentTemplate:
        """Update the loaded base studio and return the updated template.

        All arguments are forwarded as keyword arguments to
        ``BaseStudioApi.update_base_studio`` together with the id of the loaded
        base studio and the organization id. The result replaces the locally
        stored base studio.

        Returns:
            V1CloudSpaceEnvironmentTemplate: The template returned by the update call.

        Raises:
            ValueError: If the instance was created without a ``name`` and therefore
                has no base studio to update.
        """
        if self._base_studio is None:
            raise ValueError("No base studio is loaded; create BaseStudio with a name before calling update().")

        self._base_studio = self._base_studio_api.update_base_studio(
            self._base_studio.id,
            self._org.id,
            name=name,
            allowed_machines=allowed_machines,
            default_machine=default_machine,
            environment_type=environment_type,
            machine_image_version=machine_image_version,
            setup_script_text=setup_script_text,
            disabled=disabled,
        )
        # The signature promises the updated template; hand it back to the caller.
        return self._base_studio

    def list(self) -> List[V1CloudSpaceEnvironmentTemplate]:
        """List all base studios in the organization.

        Returns:
            List[V1CloudSpaceEnvironmentTemplate]: A list of base studio templates.
        """
        return self._base_studio_api.get_all_base_studios(self._org.id)
|