lightning-sdk 2025.7.17__py3-none-any.whl → 2025.7.30rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +3 -2
- lightning_sdk/api/cloud_account_api.py +204 -0
- lightning_sdk/api/deployment_api.py +11 -0
- lightning_sdk/api/job_api.py +82 -10
- lightning_sdk/api/llm_api.py +1 -1
- lightning_sdk/api/mmt_api.py +44 -5
- lightning_sdk/api/pipeline_api.py +4 -3
- lightning_sdk/api/studio_api.py +51 -8
- lightning_sdk/api/utils.py +6 -2
- lightning_sdk/cli/clusters_menu.py +3 -3
- lightning_sdk/cli/create.py +25 -11
- lightning_sdk/cli/deploy/_auth.py +19 -3
- lightning_sdk/cli/deploy/serve.py +21 -5
- lightning_sdk/cli/download.py +25 -1
- lightning_sdk/cli/entrypoint.py +4 -2
- lightning_sdk/cli/list.py +5 -1
- lightning_sdk/cli/run.py +3 -1
- lightning_sdk/cli/start.py +40 -8
- lightning_sdk/cli/switch.py +3 -1
- lightning_sdk/deployment/deployment.py +8 -0
- lightning_sdk/job/base.py +27 -3
- lightning_sdk/job/job.py +28 -4
- lightning_sdk/job/v1.py +10 -1
- lightning_sdk/job/v2.py +22 -2
- lightning_sdk/job/work.py +5 -1
- lightning_sdk/lightning_cloud/openapi/__init__.py +14 -1
- lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +428 -0
- lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +153 -48
- lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +295 -0
- lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +93 -0
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +14 -1
- lightning_sdk/lightning_cloud/openapi/models/agentmanagedendpoints_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/blogposts_id_body.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/conversations_id_body1.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/messages_id_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/metricsstream_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_schedules_body.py +81 -3
- lightning_sdk/lightning_cloud/openapi/models/schedules_id_body.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/user_id_upgradetrigger_body.py +201 -0
- lightning_sdk/lightning_cloud/openapi/models/user_user_id_body.py +201 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_billing_subscription.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_blog_post.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_settings.py +227 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_billing_upgrade_trigger_record_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_create_blog_post_request.py +53 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_checkout_session_request.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_create_subscription_checkout_session_request.py +55 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_function_call.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/{v1_get_clickhouse_assistant_session_daily_aggregated_response.py → v1_get_assistant_session_daily_aggregated_response.py} +22 -22
- lightning_sdk/lightning_cloud/openapi/models/v1_get_cluster_health_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +105 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_like_status.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_list_published_managed_endpoints_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_endpoint.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +95 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_message.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_quote_subscription_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_resource_visibility.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_response_choice.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_schedule.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_service_health.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1_status.py +79 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_tool_call.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_like_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_message_like_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +79 -313
- lightning_sdk/lightning_cloud/openapi/models/v1_volume_state.py +1 -0
- lightning_sdk/llm/llm.py +69 -11
- lightning_sdk/llm/public_assistants.json +32 -8
- lightning_sdk/machine.py +151 -43
- lightning_sdk/mmt/base.py +20 -2
- lightning_sdk/mmt/mmt.py +25 -3
- lightning_sdk/mmt/v1.py +7 -1
- lightning_sdk/mmt/v2.py +27 -3
- lightning_sdk/models.py +1 -1
- lightning_sdk/organization.py +4 -0
- lightning_sdk/pipeline/pipeline.py +16 -5
- lightning_sdk/pipeline/printer.py +5 -3
- lightning_sdk/pipeline/schedule.py +844 -1
- lightning_sdk/pipeline/steps.py +19 -4
- lightning_sdk/sandbox.py +4 -1
- lightning_sdk/serve.py +2 -0
- lightning_sdk/studio.py +91 -44
- lightning_sdk/teamspace.py +19 -10
- lightning_sdk/utils/resolve.py +37 -2
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/METADATA +7 -5
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/RECORD +98 -85
- lightning_sdk/api/cluster_api.py +0 -119
- /lightning_sdk/cli/{inspect.py → inspection.py} +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/LICENSE +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/WHEEL +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py
CHANGED
|
@@ -4,7 +4,7 @@ from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__ # noq
|
|
|
4
4
|
from lightning_sdk.deployment import Deployment
|
|
5
5
|
from lightning_sdk.helpers import _check_version_and_prompt_upgrade, _set_tqdm_envvars_noninteractive
|
|
6
6
|
from lightning_sdk.job import Job
|
|
7
|
-
from lightning_sdk.machine import Machine
|
|
7
|
+
from lightning_sdk.machine import CloudProvider, Machine
|
|
8
8
|
from lightning_sdk.mmt import MMT
|
|
9
9
|
from lightning_sdk.organization import Organization
|
|
10
10
|
from lightning_sdk.plugin import JobsPlugin, MultiMachineTrainingPlugin, Plugin, SlurmJobsPlugin
|
|
@@ -16,6 +16,7 @@ from lightning_sdk.user import User
|
|
|
16
16
|
__all__ = [
|
|
17
17
|
"AIHub",
|
|
18
18
|
"Agent",
|
|
19
|
+
"CloudProvider",
|
|
19
20
|
"Deployment",
|
|
20
21
|
"Job",
|
|
21
22
|
"JobsPlugin",
|
|
@@ -31,6 +32,6 @@ __all__ = [
|
|
|
31
32
|
"User",
|
|
32
33
|
]
|
|
33
34
|
|
|
34
|
-
__version__ = "2025.07.
|
|
35
|
+
__version__ = "2025.07.30rc0"
|
|
35
36
|
_check_version_and_prompt_upgrade(__version__)
|
|
36
37
|
_set_tqdm_envvars_noninteractive()
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from lightning_sdk.lightning_cloud.openapi import (
|
|
5
|
+
Externalv1Cluster,
|
|
6
|
+
V1CloudProvider,
|
|
7
|
+
V1ClusterType,
|
|
8
|
+
V1ExternalCluster,
|
|
9
|
+
V1ListClusterAcceleratorsResponse,
|
|
10
|
+
V1ListDefaultClusterAcceleratorsResponse,
|
|
11
|
+
)
|
|
12
|
+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from lightning_sdk.machine import CloudProvider
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CloudAccountApi:
|
|
19
|
+
"""Internal API client for API requests to cluster endpoints."""
|
|
20
|
+
|
|
21
|
+
def __init__(self) -> None:
|
|
22
|
+
self._client = LightningClient(max_tries=7)
|
|
23
|
+
|
|
24
|
+
def get_cloud_account(self, cloud_account_id: str, teamspace_id: str, org_id: str) -> Externalv1Cluster:
|
|
25
|
+
"""Gets the cluster from given params cluster_id, project_id and owner.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
cloud_account_id: the cloud account to get
|
|
29
|
+
teamspace_id: the teamspace the cloud_account is supposed to be associated with
|
|
30
|
+
org_id: The owning org of this teamspace
|
|
31
|
+
|
|
32
|
+
"""
|
|
33
|
+
res = self._client.cluster_service_get_cluster(id=cloud_account_id, org_id=org_id, project_id=teamspace_id)
|
|
34
|
+
if not res:
|
|
35
|
+
raise ValueError(f"CloudAccount {cloud_account_id} does not exist")
|
|
36
|
+
return res
|
|
37
|
+
|
|
38
|
+
@lru_cache(maxsize=None) # noqa: B019
|
|
39
|
+
def list_cloud_accounts(self, teamspace_id: str) -> List[V1ExternalCluster]:
|
|
40
|
+
"""Lists the cloud accounts for a given teamspace.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
teamspace_id: The teamspace to list cloud accounts for
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
A list of cloud accounts
|
|
47
|
+
"""
|
|
48
|
+
res_project = self._client.cluster_service_list_project_clusters(
|
|
49
|
+
project_id=teamspace_id,
|
|
50
|
+
)
|
|
51
|
+
res_global = self._client.cluster_service_list_clusters(
|
|
52
|
+
project_id=teamspace_id,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# can't use set here because the cloud_accounts are not hashable
|
|
56
|
+
cloud_accounts = []
|
|
57
|
+
cloud_account_ids = []
|
|
58
|
+
for cloud_account in res_project.clusters + res_global.clusters:
|
|
59
|
+
if cloud_account.id not in cloud_account_ids:
|
|
60
|
+
cloud_accounts.append(cloud_account)
|
|
61
|
+
cloud_account_ids.append(cloud_account.id)
|
|
62
|
+
|
|
63
|
+
return cloud_accounts
|
|
64
|
+
|
|
65
|
+
def get_cloud_account_non_org(self, teamspace_id: str, cloud_account_id: str) -> Optional[V1ExternalCluster]:
|
|
66
|
+
for cluster in self.list_cloud_accounts(teamspace_id=teamspace_id):
|
|
67
|
+
if cluster.id == cloud_account_id:
|
|
68
|
+
return cluster
|
|
69
|
+
|
|
70
|
+
return None
|
|
71
|
+
|
|
72
|
+
@lru_cache(maxsize=None) # noqa: B019
|
|
73
|
+
def list_cloud_account_accelerators(
|
|
74
|
+
self,
|
|
75
|
+
teamspace_id: str,
|
|
76
|
+
cloud_account_id: str,
|
|
77
|
+
org_id: str,
|
|
78
|
+
) -> Union[V1ListClusterAcceleratorsResponse, V1ListDefaultClusterAcceleratorsResponse]:
|
|
79
|
+
"""Lists the accelerators for a given cloud account.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
cloud_account_id: cluster ID to list accelerators for
|
|
83
|
+
"""
|
|
84
|
+
# map cloud_account to provider
|
|
85
|
+
cloud_provider = None
|
|
86
|
+
is_default = True
|
|
87
|
+
for cloud_account in self.list_cloud_accounts(teamspace_id=teamspace_id):
|
|
88
|
+
if cloud_account.id == cloud_account_id:
|
|
89
|
+
is_default = cloud_account.spec.cluster_type == V1ClusterType.GLOBAL
|
|
90
|
+
cloud_provider = self._get_cloud_account_provider(cloud_account)
|
|
91
|
+
break
|
|
92
|
+
|
|
93
|
+
if cloud_provider is None:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"Cloud Account {cloud_account_id} is not a default cloud account. Are you in the correct teamspace?"
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
if is_default:
|
|
99
|
+
res = self._list_default_cluster_accelerators(teamspace_id=teamspace_id, cloud_provider=cloud_provider)
|
|
100
|
+
else:
|
|
101
|
+
res = self._client.cluster_service_list_cluster_accelerators(
|
|
102
|
+
id=cloud_account_id,
|
|
103
|
+
org_id=org_id,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if not res:
|
|
107
|
+
raise ValueError(f"CloudAccount {cloud_account_id} does not exist")
|
|
108
|
+
return res
|
|
109
|
+
|
|
110
|
+
def _list_default_cluster_accelerators(
|
|
111
|
+
self, teamspace_id: str, cloud_provider: Union[str, "CloudProvider"]
|
|
112
|
+
) -> V1ListDefaultClusterAcceleratorsResponse:
|
|
113
|
+
return self._client.cluster_service_list_default_cluster_accelerators(
|
|
114
|
+
project_id=teamspace_id, cloud_provider=str(cloud_provider)
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
@lru_cache(maxsize=None) # noqa: B019
|
|
118
|
+
def list_global_cloud_accounts(self, teamspace_id: str) -> List[V1ExternalCluster]:
|
|
119
|
+
"""Lists the accelerators for a given teamspace.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
teamspace_id: id of the teamspace to get the associated cloud_accounts for
|
|
123
|
+
"""
|
|
124
|
+
cloud_accounts = self.list_cloud_accounts(teamspace_id=teamspace_id)
|
|
125
|
+
if not cloud_accounts:
|
|
126
|
+
raise ValueError(f"Teamspace {teamspace_id} does not exist")
|
|
127
|
+
filtered_cloud_accounts = filter(lambda x: x.spec.cluster_type == V1ClusterType.GLOBAL, cloud_accounts)
|
|
128
|
+
# TODO: remove aggregate filter once finished
|
|
129
|
+
filtered_cloud_accounts = filter(
|
|
130
|
+
lambda x: x.spec.driver != V1CloudProvider.LIGHTNING_AGGREGATE, filtered_cloud_accounts
|
|
131
|
+
)
|
|
132
|
+
return list(filtered_cloud_accounts)
|
|
133
|
+
|
|
134
|
+
def get_cloud_account_provider_mapping(self, teamspace_id: str) -> Dict["CloudProvider", str]:
|
|
135
|
+
"""Gets the cloud account <-> provider mapping."""
|
|
136
|
+
res = self.list_global_cloud_accounts(teamspace_id=teamspace_id)
|
|
137
|
+
return {self._get_cloud_account_provider(cloud_account): cloud_account.id for cloud_account in res}
|
|
138
|
+
|
|
139
|
+
@staticmethod
|
|
140
|
+
def _get_cloud_account_provider(cloud_account: Optional[V1ExternalCluster]) -> "CloudProvider":
|
|
141
|
+
"""Determines the cloud provider based on the cloud_account configuration.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
cloud_account: An optional Externalv1Cluster object containing cluster specifications
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
CloudProvider: The determined cloud provider, defaults to AWS if no match is found
|
|
148
|
+
"""
|
|
149
|
+
from lightning_sdk.machine import CloudProvider
|
|
150
|
+
|
|
151
|
+
if not cloud_account:
|
|
152
|
+
return CloudProvider.AWS
|
|
153
|
+
|
|
154
|
+
if cloud_account.spec and cloud_account.spec.driver:
|
|
155
|
+
if cloud_account.spec.driver == V1CloudProvider.LIGHTNING:
|
|
156
|
+
return CloudProvider.LIGHTNING
|
|
157
|
+
|
|
158
|
+
if cloud_account.spec.driver == V1CloudProvider.DGX:
|
|
159
|
+
return CloudProvider.DGX
|
|
160
|
+
|
|
161
|
+
if cloud_account.spec:
|
|
162
|
+
if cloud_account.spec.aws_v1:
|
|
163
|
+
return CloudProvider.AWS
|
|
164
|
+
if cloud_account.spec.google_cloud_v1:
|
|
165
|
+
return CloudProvider.GCP
|
|
166
|
+
if cloud_account.spec.lambda_labs_v1:
|
|
167
|
+
return CloudProvider.LAMBDA_LABS
|
|
168
|
+
if cloud_account.spec.vultr_v1:
|
|
169
|
+
return CloudProvider.VULTR
|
|
170
|
+
if cloud_account.spec.voltage_park_v1:
|
|
171
|
+
return CloudProvider.VOLTAGE_PARK
|
|
172
|
+
if cloud_account.spec.nebius_v1:
|
|
173
|
+
return CloudProvider.NEBIUS
|
|
174
|
+
|
|
175
|
+
return CloudProvider.AWS
|
|
176
|
+
|
|
177
|
+
def resolve_cloud_account(
|
|
178
|
+
self,
|
|
179
|
+
teamspace_id: str,
|
|
180
|
+
cloud_account: Optional[str],
|
|
181
|
+
cloud_provider: Optional[Union["CloudProvider", str]],
|
|
182
|
+
default_cloud_account: Optional[str],
|
|
183
|
+
) -> Optional[str]:
|
|
184
|
+
if cloud_account:
|
|
185
|
+
if cloud_provider:
|
|
186
|
+
cloud_account_resp = self.get_cloud_account_non_org(teamspace_id, cloud_account)
|
|
187
|
+
cloud_provider_resp = self._get_cloud_account_provider(cloud_account_resp)
|
|
188
|
+
if cloud_provider_resp != cloud_provider:
|
|
189
|
+
raise RuntimeError(
|
|
190
|
+
f"Specified both cloud_provider ({cloud_provider}) and "
|
|
191
|
+
"cloud_account ({cloud_account} has cloud provider {cloud_provider_resp}) which don't match!"
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
return cloud_account
|
|
195
|
+
|
|
196
|
+
if cloud_provider:
|
|
197
|
+
cloud_account_mapping = self.get_cloud_account_provider_mapping(teamspace_id=teamspace_id)
|
|
198
|
+
if cloud_provider and cloud_provider in cloud_account_mapping:
|
|
199
|
+
return cloud_account_mapping[cloud_provider]
|
|
200
|
+
|
|
201
|
+
if default_cloud_account:
|
|
202
|
+
return default_cloud_account
|
|
203
|
+
|
|
204
|
+
return None
|
|
@@ -264,6 +264,7 @@ class DeploymentApi:
|
|
|
264
264
|
custom_domain: Optional[str] = None,
|
|
265
265
|
quantity: Optional[int] = None,
|
|
266
266
|
include_credentials: Optional[bool] = None,
|
|
267
|
+
max_runtime: Optional[int] = None,
|
|
267
268
|
) -> V1Deployment:
|
|
268
269
|
# Update the deployment in place
|
|
269
270
|
|
|
@@ -291,6 +292,9 @@ class DeploymentApi:
|
|
|
291
292
|
requires_release |= apply_change(deployment.spec, "spot", spot)
|
|
292
293
|
requires_release |= apply_change(deployment.spec, "quantity", quantity)
|
|
293
294
|
requires_release |= apply_change(deployment.spec, "include_credentials", include_credentials)
|
|
295
|
+
requires_release |= apply_change(
|
|
296
|
+
deployment.spec, "requested_run_duration_seconds", str(max_runtime) if max_runtime is not None else None
|
|
297
|
+
)
|
|
294
298
|
|
|
295
299
|
if requires_release:
|
|
296
300
|
if deployment.strategy is None:
|
|
@@ -569,6 +573,7 @@ def to_spec(
|
|
|
569
573
|
quantity: Optional[int] = None,
|
|
570
574
|
include_credentials: Optional[bool] = None,
|
|
571
575
|
cloudspace_id: Optional[None] = None,
|
|
576
|
+
max_runtime: Optional[int] = None,
|
|
572
577
|
) -> V1JobSpec:
|
|
573
578
|
if cloud_account is None:
|
|
574
579
|
raise ValueError("The cloud account should be defined.")
|
|
@@ -585,6 +590,11 @@ def to_spec(
|
|
|
585
590
|
if command is None and cloudspace_id is not None:
|
|
586
591
|
raise ValueError("The command should be defined.")
|
|
587
592
|
|
|
593
|
+
# need to go via kwargs for typing compatibility since autogenerated apis accept None but aren't typed with None
|
|
594
|
+
optional_spec_kwargs = {}
|
|
595
|
+
if max_runtime:
|
|
596
|
+
optional_spec_kwargs["requested_run_duration_seconds"] = str(max_runtime)
|
|
597
|
+
|
|
588
598
|
return V1JobSpec(
|
|
589
599
|
cluster_id=cloud_account,
|
|
590
600
|
command=command,
|
|
@@ -597,6 +607,7 @@ def to_spec(
|
|
|
597
607
|
quantity=quantity,
|
|
598
608
|
include_credentials=include_credentials,
|
|
599
609
|
cloudspace_id=cloudspace_id,
|
|
610
|
+
**optional_spec_kwargs,
|
|
600
611
|
)
|
|
601
612
|
|
|
602
613
|
|
lightning_sdk/api/job_api.py
CHANGED
|
@@ -17,7 +17,7 @@ from lightning_sdk.lightning_cloud.openapi import (
|
|
|
17
17
|
JobsIdBody1,
|
|
18
18
|
ProjectIdJobsBody,
|
|
19
19
|
V1CloudSpace,
|
|
20
|
-
|
|
20
|
+
V1ClusterAccelerator,
|
|
21
21
|
V1DownloadJobLogsResponse,
|
|
22
22
|
V1DownloadLightningappInstanceLogsResponse,
|
|
23
23
|
V1EnvVar,
|
|
@@ -94,15 +94,48 @@ class JobApiV1:
|
|
|
94
94
|
def get_work(self, job_id: str, teamspace_id: str, work_id: str) -> Externalv1Lightningwork:
|
|
95
95
|
return self._client.lightningwork_service_get_lightningwork(project_id=teamspace_id, app_id=job_id, id=work_id)
|
|
96
96
|
|
|
97
|
-
def get_machine_from_work(self, work: Externalv1Lightningwork) -> Machine:
|
|
97
|
+
def get_machine_from_work(self, work: Externalv1Lightningwork, org_id: str) -> Machine:
|
|
98
98
|
spec: V1LightningworkSpec = work.spec
|
|
99
99
|
# prefer user-requested config if specified
|
|
100
100
|
user_requested_compute_config: V1UserRequestedComputeConfig = spec.user_requested_compute_config
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
accelerators = self._get_machines_for_cloud_account(
|
|
102
|
+
teamspace_id=work.project_id,
|
|
103
|
+
cloud_account_id=spec.cluster_id,
|
|
104
|
+
org_id=org_id,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
identifier = None
|
|
108
|
+
|
|
109
|
+
if user_requested_compute_config and user_requested_compute_config.name:
|
|
110
|
+
identifier = user_requested_compute_config.name
|
|
111
|
+
else:
|
|
112
|
+
identifier = spec.compute_config.instance_type
|
|
113
|
+
|
|
114
|
+
for accelerator in accelerators:
|
|
115
|
+
if identifier in (
|
|
116
|
+
accelerator.slug,
|
|
117
|
+
accelerator.slug_multi_cloud,
|
|
118
|
+
accelerator.instance_id,
|
|
119
|
+
):
|
|
120
|
+
return Machine.from_str(accelerator.slug_multi_cloud)
|
|
121
|
+
|
|
122
|
+
return Machine.from_str(identifier)
|
|
104
123
|
|
|
105
|
-
|
|
124
|
+
def _get_machines_for_cloud_account(
|
|
125
|
+
self, teamspace_id: str, cloud_account_id: str, org_id: str
|
|
126
|
+
) -> List[V1ClusterAccelerator]:
|
|
127
|
+
from lightning_sdk.api.cloud_account_api import CloudAccountApi
|
|
128
|
+
|
|
129
|
+
cloud_account_api = CloudAccountApi()
|
|
130
|
+
accelerators = cloud_account_api.list_cloud_account_accelerators(
|
|
131
|
+
teamspace_id=teamspace_id,
|
|
132
|
+
cloud_account_id=cloud_account_id,
|
|
133
|
+
org_id=org_id,
|
|
134
|
+
)
|
|
135
|
+
if not accelerators:
|
|
136
|
+
return []
|
|
137
|
+
|
|
138
|
+
return list(filter(lambda acc: acc.enabled, accelerators.accelerator))
|
|
106
139
|
|
|
107
140
|
def get_studio_name(self, job: Externalv1LightningappInstance) -> str:
|
|
108
141
|
cs: V1CloudSpace = self._client.cloud_space_service_get_cloud_space(
|
|
@@ -215,6 +248,7 @@ class JobApiV2:
|
|
|
215
248
|
path_mappings: Optional[Dict[str, str]],
|
|
216
249
|
artifacts_local: Optional[str], # deprecated in favor of path_mappings
|
|
217
250
|
artifacts_remote: Optional[str], # deprecated in favor of path_mappings
|
|
251
|
+
max_runtime: Optional[int] = None,
|
|
218
252
|
) -> V1Job:
|
|
219
253
|
body = self._create_job_body(
|
|
220
254
|
name=name,
|
|
@@ -231,6 +265,7 @@ class JobApiV2:
|
|
|
231
265
|
path_mappings=path_mappings,
|
|
232
266
|
artifacts_local=artifacts_local,
|
|
233
267
|
artifacts_remote=artifacts_remote,
|
|
268
|
+
max_runtime=max_runtime,
|
|
234
269
|
)
|
|
235
270
|
|
|
236
271
|
job: V1Job = self._client.jobs_service_create_job(project_id=teamspace_id, body=body)
|
|
@@ -252,6 +287,7 @@ class JobApiV2:
|
|
|
252
287
|
path_mappings: Optional[Dict[str, str]],
|
|
253
288
|
artifacts_local: Optional[str], # deprecated in favor of path_mappings
|
|
254
289
|
artifacts_remote: Optional[str], # deprecated in favor of path_mappings)
|
|
290
|
+
max_runtime: Optional[int] = None,
|
|
255
291
|
) -> ProjectIdJobsBody:
|
|
256
292
|
env_vars = []
|
|
257
293
|
if env is not None:
|
|
@@ -268,6 +304,11 @@ class JobApiV2:
|
|
|
268
304
|
artifacts_remote=artifacts_remote,
|
|
269
305
|
)
|
|
270
306
|
|
|
307
|
+
# need to go via kwargs for typing compatibility since autogenerated apis accept None but aren't typed with None
|
|
308
|
+
optional_spec_kwargs = {}
|
|
309
|
+
if max_runtime:
|
|
310
|
+
optional_spec_kwargs["requested_run_duration_seconds"] = str(max_runtime)
|
|
311
|
+
|
|
271
312
|
spec = V1JobSpec(
|
|
272
313
|
cloudspace_id=studio_id or "",
|
|
273
314
|
cluster_id=cloud_account or "",
|
|
@@ -281,6 +322,7 @@ class JobApiV2:
|
|
|
281
322
|
image_cluster_credentials=cloud_account_auth,
|
|
282
323
|
image_secret_ref=image_credentials or "",
|
|
283
324
|
path_mappings=path_mappings_list,
|
|
325
|
+
**optional_spec_kwargs,
|
|
284
326
|
)
|
|
285
327
|
return ProjectIdJobsBody(name=name, spec=spec)
|
|
286
328
|
|
|
@@ -371,11 +413,41 @@ class JobApiV2:
|
|
|
371
413
|
return Status.Stopping
|
|
372
414
|
return Status.Pending
|
|
373
415
|
|
|
374
|
-
def _get_job_machine_from_spec(self, spec: V1JobSpec) -> "Machine":
|
|
375
|
-
|
|
376
|
-
|
|
416
|
+
def _get_job_machine_from_spec(self, spec: V1JobSpec, teamspace_id: str, org_id: str) -> "Machine":
|
|
417
|
+
accelerators = self._get_machines_for_cloud_account(
|
|
418
|
+
teamspace_id=teamspace_id,
|
|
419
|
+
cloud_account_id=spec.cluster_id,
|
|
420
|
+
org_id=org_id,
|
|
421
|
+
)
|
|
422
|
+
|
|
423
|
+
for accelerator in accelerators:
|
|
424
|
+
possible_identifiers = (
|
|
425
|
+
accelerator.slug,
|
|
426
|
+
accelerator.slug_multi_cloud,
|
|
427
|
+
accelerator.instance_id,
|
|
428
|
+
)
|
|
429
|
+
if (spec.instance_name and spec.instance_name in possible_identifiers) or (
|
|
430
|
+
spec.instance_type and spec.instance_type in possible_identifiers
|
|
431
|
+
):
|
|
432
|
+
return Machine.from_str(accelerator.slug_multi_cloud)
|
|
433
|
+
|
|
434
|
+
return Machine.from_str(spec.instance_name or spec.instance_type)
|
|
435
|
+
|
|
436
|
+
def _get_machines_for_cloud_account(
|
|
437
|
+
self, teamspace_id: str, cloud_account_id: str, org_id: str
|
|
438
|
+
) -> List[V1ClusterAccelerator]:
|
|
439
|
+
from lightning_sdk.api.cloud_account_api import CloudAccountApi
|
|
440
|
+
|
|
441
|
+
cloud_account_api = CloudAccountApi()
|
|
442
|
+
accelerators = cloud_account_api.list_cloud_account_accelerators(
|
|
443
|
+
teamspace_id=teamspace_id,
|
|
444
|
+
cloud_account_id=cloud_account_id,
|
|
445
|
+
org_id=org_id,
|
|
446
|
+
)
|
|
447
|
+
if not accelerators:
|
|
448
|
+
return []
|
|
377
449
|
|
|
378
|
-
return
|
|
450
|
+
return list(filter(lambda acc: acc.enabled, accelerators.accelerator))
|
|
379
451
|
|
|
380
452
|
def get_total_cost(self, job: V1Job) -> float:
|
|
381
453
|
return job.total_cost
|
lightning_sdk/api/llm_api.py
CHANGED
|
@@ -98,7 +98,7 @@ class LLMApi:
|
|
|
98
98
|
{"contentType": "text", "parts": [prompt]},
|
|
99
99
|
],
|
|
100
100
|
},
|
|
101
|
-
"
|
|
101
|
+
"max_tokens": max_completion_tokens,
|
|
102
102
|
"conversation_id": conversation_id,
|
|
103
103
|
"billing_project_id": billing_project_id,
|
|
104
104
|
"name": name,
|
lightning_sdk/api/mmt_api.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
|
2
2
|
import time
|
|
3
3
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
4
4
|
|
|
5
|
-
from lightning_sdk.api.job_api import JobApiV1
|
|
5
|
+
from lightning_sdk.api.job_api import JobApiV1, V1ClusterAccelerator
|
|
6
6
|
from lightning_sdk.api.utils import (
|
|
7
7
|
_create_app,
|
|
8
8
|
_machine_to_compute_name,
|
|
@@ -87,6 +87,7 @@ class MMTApiV2:
|
|
|
87
87
|
path_mappings: Optional[Dict[str, str]],
|
|
88
88
|
artifacts_local: Optional[str], # deprecated in favor of path_mappings
|
|
89
89
|
artifacts_remote: Optional[str], # deprecated in favor of path_mappings
|
|
90
|
+
max_runtime: Optional[int],
|
|
90
91
|
) -> V1MultiMachineJob:
|
|
91
92
|
body = self._create_mmt_body(
|
|
92
93
|
name=name,
|
|
@@ -104,6 +105,7 @@ class MMTApiV2:
|
|
|
104
105
|
path_mappings=path_mappings,
|
|
105
106
|
artifacts_local=artifacts_local, # deprecated in favor of path_mappings
|
|
106
107
|
artifacts_remote=artifacts_remote, # deprecated in favor of path_mappings
|
|
108
|
+
max_runtime=max_runtime,
|
|
107
109
|
)
|
|
108
110
|
|
|
109
111
|
job: V1MultiMachineJob = self._client.jobs_service_create_multi_machine_job(project_id=teamspace_id, body=body)
|
|
@@ -126,6 +128,7 @@ class MMTApiV2:
|
|
|
126
128
|
path_mappings: Optional[Dict[str, str]],
|
|
127
129
|
artifacts_local: Optional[str], # deprecated in favor of path_mappings
|
|
128
130
|
artifacts_remote: Optional[str], # deprecated in favor of path_mappings
|
|
131
|
+
max_runtime: Optional[int] = None,
|
|
129
132
|
) -> ProjectIdMultimachinejobsBody:
|
|
130
133
|
env_vars = []
|
|
131
134
|
if env is not None:
|
|
@@ -142,6 +145,11 @@ class MMTApiV2:
|
|
|
142
145
|
artifacts_remote=artifacts_remote,
|
|
143
146
|
)
|
|
144
147
|
|
|
148
|
+
# need to go via kwargs for typing compatibility since autogenerated apis accept None but aren't typed with None
|
|
149
|
+
optional_spec_kwargs = {}
|
|
150
|
+
if max_runtime:
|
|
151
|
+
optional_spec_kwargs["requested_run_duration_seconds"] = str(max_runtime)
|
|
152
|
+
|
|
145
153
|
spec = V1JobSpec(
|
|
146
154
|
cloudspace_id=studio_id or "",
|
|
147
155
|
cluster_id=cloud_account or "",
|
|
@@ -155,6 +163,7 @@ class MMTApiV2:
|
|
|
155
163
|
image_cluster_credentials=cloud_account_auth,
|
|
156
164
|
image_secret_ref=image_credentials or "",
|
|
157
165
|
path_mappings=path_mappings_list,
|
|
166
|
+
**optional_spec_kwargs,
|
|
158
167
|
)
|
|
159
168
|
return ProjectIdMultimachinejobsBody(
|
|
160
169
|
name=name, spec=spec, cluster_id=cloud_account or "", machines=num_machines
|
|
@@ -238,11 +247,41 @@ class MMTApiV2:
|
|
|
238
247
|
def get_command(self, job: V1MultiMachineJob) -> str:
|
|
239
248
|
return job.spec.command
|
|
240
249
|
|
|
241
|
-
def _get_job_machine_from_spec(self, spec: V1JobSpec) -> "Machine":
|
|
242
|
-
|
|
243
|
-
|
|
250
|
+
def _get_job_machine_from_spec(self, spec: V1JobSpec, teamspace_id: str, org_id: str) -> "Machine":
|
|
251
|
+
accelerators = self._get_machines_for_cloud_account(
|
|
252
|
+
teamspace_id=teamspace_id,
|
|
253
|
+
cloud_account_id=spec.cluster_id,
|
|
254
|
+
org_id=org_id,
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
for accelerator in accelerators:
|
|
258
|
+
possible_identifiers = (
|
|
259
|
+
accelerator.slug,
|
|
260
|
+
accelerator.slug_multi_cloud,
|
|
261
|
+
accelerator.instance_id,
|
|
262
|
+
)
|
|
263
|
+
if (spec.instance_name and spec.instance_name in possible_identifiers) or (
|
|
264
|
+
spec.instance_type and spec.instance_type in possible_identifiers
|
|
265
|
+
):
|
|
266
|
+
return Machine.from_str(accelerator.slug_multi_cloud)
|
|
267
|
+
|
|
268
|
+
return Machine.from_str(spec.instance_name or spec.instance_type)
|
|
269
|
+
|
|
270
|
+
def _get_machines_for_cloud_account(
|
|
271
|
+
self, teamspace_id: str, cloud_account_id: str, org_id: str
|
|
272
|
+
) -> List[V1ClusterAccelerator]:
|
|
273
|
+
from lightning_sdk.api.cloud_account_api import CloudAccountApi
|
|
274
|
+
|
|
275
|
+
cloud_account_api = CloudAccountApi()
|
|
276
|
+
accelerators = cloud_account_api.list_cloud_account_accelerators(
|
|
277
|
+
teamspace_id=teamspace_id,
|
|
278
|
+
cloud_account_id=cloud_account_id,
|
|
279
|
+
org_id=org_id,
|
|
280
|
+
)
|
|
281
|
+
if not accelerators:
|
|
282
|
+
return []
|
|
244
283
|
|
|
245
|
-
return
|
|
284
|
+
return list(filter(lambda acc: acc.enabled, accelerators.accelerator))
|
|
246
285
|
|
|
247
286
|
def get_total_cost(self, job: V1MultiMachineJob) -> float:
|
|
248
287
|
return job.total_cost
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, List, Optional, Union
|
|
2
2
|
|
|
3
|
-
from lightning_sdk.api.
|
|
3
|
+
from lightning_sdk.api.cloud_account_api import CloudAccountApi
|
|
4
4
|
from lightning_sdk.lightning_cloud.openapi.models import (
|
|
5
5
|
ProjectIdPipelinesBody,
|
|
6
6
|
ProjectIdSchedulesBody,
|
|
@@ -23,7 +23,7 @@ class PipelineApi:
|
|
|
23
23
|
|
|
24
24
|
def __init__(self) -> None:
|
|
25
25
|
self._client = LightningClient(max_tries=0, retry=False)
|
|
26
|
-
self.
|
|
26
|
+
self._cloud_account_api = CloudAccountApi()
|
|
27
27
|
|
|
28
28
|
def get_pipeline_by_id(self, project_id: str, pipeline_id_or_name: str) -> Optional[V1Pipeline]:
|
|
29
29
|
if pipeline_id_or_name.startswith("pip_"):
|
|
@@ -75,6 +75,7 @@ class PipelineApi:
|
|
|
75
75
|
resource_id=pipeline.id,
|
|
76
76
|
parent_resource_id=parent_pipeline_id or "",
|
|
77
77
|
resource_type=V1ScheduleResourceType.PIPELINE,
|
|
78
|
+
timezone=schedule.timezone,
|
|
78
79
|
)
|
|
79
80
|
|
|
80
81
|
self._client.schedules_service_create_schedule(body, teamspace.id)
|
|
@@ -97,7 +98,7 @@ class PipelineApi:
|
|
|
97
98
|
|
|
98
99
|
from lightning_sdk.pipeline.utils import _get_cloud_account
|
|
99
100
|
|
|
100
|
-
clusters = self.
|
|
101
|
+
clusters = self._cloud_account_api.list_cloud_accounts(teamspace_id=teamspace.id)
|
|
101
102
|
|
|
102
103
|
selected_cluster = None
|
|
103
104
|
selected_cluster_id = _get_cloud_account(steps)
|