lightning-sdk 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lightning_sdk/__init__.py CHANGED
@@ -27,5 +27,5 @@ __all__ = [
27
27
  "AIHub",
28
28
  ]
29
29
 
30
- __version__ = "0.1.32"
30
+ __version__ = "0.1.33"
31
31
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -1,6 +1,12 @@
1
- from typing import Dict, List
1
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
2
2
 
3
- from lightning_sdk.api.ai_hub_api import AIHubApi
3
+ from lightning_sdk.api import AIHubApi, UserApi
4
+ from lightning_sdk.lightning_cloud import login
5
+ from lightning_sdk.user import User
6
+ from lightning_sdk.utils.resolve import _resolve_org, _resolve_teamspace
7
+
8
+ if TYPE_CHECKING:
9
+ from lightning_sdk import Organization, Teamspace
4
10
 
5
11
 
6
12
  class AIHub:
@@ -13,10 +19,17 @@ class AIHub:
13
19
 
14
20
  def __init__(self) -> None:
15
21
  self._api = AIHubApi()
22
+ self._auth = None
23
+
24
+ def list_apis(self, search: Optional[str] = None) -> List[Dict[str, str]]:
25
+ """Get a list of AI Hub API templates.
16
26
 
17
- def list_apis(self) -> List[Dict[str, str]]:
18
- """Get a list of AI Hub API templates."""
19
- api_templates = self._api.list_apis()
27
+ Example:
28
+ api_hub = AIHub()
29
+ api_list = api_hub.list_apis(search="Llama")
30
+ """
31
+ search_query = search or ""
32
+ api_templates = self._api.list_apis(search_query=search_query)
20
33
  results = []
21
34
  for template in api_templates:
22
35
  result = {
@@ -24,6 +37,74 @@ class AIHub:
24
37
  "name": template.name,
25
38
  "description": template.description,
26
39
  "creator_username": template.creator_username,
40
+ "created_on": template.creation_timestamp.strftime("%Y-%m-%d %H:%M:%S")
41
+ if template.creation_timestamp
42
+ else None,
27
43
  }
28
44
  results.append(result)
29
45
  return results
46
+
47
+ def _authenticate(
48
+ self,
49
+ teamspace: Optional[Union[str, "Teamspace"]] = None,
50
+ org: Optional[Union[str, "Organization"]] = None,
51
+ user: Optional[Union[str, "User"]] = None,
52
+ ) -> "Teamspace":
53
+ if self._auth is None:
54
+ self._auth = login.Auth()
55
+ try:
56
+ self._auth.authenticate()
57
+ user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
58
+ except ConnectionError as e:
59
+ raise e
60
+
61
+ org = _resolve_org(org)
62
+ teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user if org is None else None)
63
+ if teamspace is None:
64
+ raise ValueError("You need to pass a teamspace or an org for your deployment.")
65
+ return teamspace
66
+
67
+ def deploy(
68
+ self,
69
+ api_id: str,
70
+ cluster: Optional[str] = None,
71
+ name: Optional[str] = None,
72
+ teamspace: Optional[Union[str, "Teamspace"]] = None,
73
+ org: Optional[Union[str, "Organization"]] = None,
74
+ **kwargs: Dict[str, Any],
75
+ ) -> Dict[str, Union[str, bool]]:
76
+ """Deploy an API from the AI Hub.
77
+
78
+ Example:
79
+ from lightning_sdk import AIHub
80
+ ai_hub = AIHub()
81
+ deployment = ai_hub.deploy("temp_01jc37n6qpqkdptjpyep0z06hy", batch_size=10)
82
+
83
+ Args:
84
+ api_id: The ID of the API you want to deploy.
85
+ cluster: The cluster where you want to deploy the API, such as "lightning-public-prod".
86
+ name: Name for the deployed API. Defaults to None.
87
+ teamspace: The team or group for deployment. Defaults to None.
88
+ org: The organization for deployment. Defaults to None.
89
+ **kwargs: Additional keyword arguments for deployment.
90
+
91
+ Returns:
92
+ A dictionary containing the name of the deployed API,
93
+ the URL to access it, and whether it is interruptible.
94
+
95
+ Raises:
96
+ ValueError: If a teamspace or organization is not provided.
97
+ ConnectionError: If there is an issue with logging in.
98
+ """
99
+ teamspace = self._authenticate(teamspace, org)
100
+ teamspace_id = teamspace.id
101
+
102
+ deployment = self._api.deploy_api(
103
+ template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, **kwargs
104
+ )
105
+ return {
106
+ "id": deployment.id,
107
+ "name": deployment.name,
108
+ "base_url": deployment.status.urls[0],
109
+ "interruptible": deployment.spec.spot,
110
+ }
@@ -1,5 +1,11 @@
1
- from typing import List
1
+ import re
2
+ from typing import List, Optional
2
3
 
4
+ from lightning_sdk.lightning_cloud.openapi.models import (
5
+ CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs,
6
+ V1Deployment,
7
+ V1ParameterizationSpec,
8
+ )
3
9
  from lightning_sdk.lightning_cloud.openapi.models.v1_deployment_template_gallery_response import (
4
10
  V1DeploymentTemplateGalleryResponse,
5
11
  )
@@ -8,8 +14,48 @@ from lightning_sdk.lightning_cloud.rest_client import LightningClient
8
14
 
9
15
  class AIHubApi:
10
16
  def __init__(self) -> None:
11
- self._client = LightningClient(max_tries=7)
17
+ self._client = LightningClient(max_tries=3)
12
18
 
13
- def list_apis(self) -> List[V1DeploymentTemplateGalleryResponse]:
19
+ def list_apis(self, search_query: str) -> List[V1DeploymentTemplateGalleryResponse]:
14
20
  kwargs = {"show_globally_visible": True}
15
- return self._client.deployment_templates_service_list_published_deployment_templates(**kwargs).templates
21
+ return self._client.deployment_templates_service_list_published_deployment_templates(
22
+ search_query=search_query, **kwargs
23
+ ).templates
24
+
25
+ @staticmethod
26
+ def _parse_and_update_args(cmd: str, **kwargs: dict) -> list:
27
+ """Parse the command and update the arguments with the provided kwargs.
28
+
29
+ >>> _parse_and_update_args("--arg1 1 --arg2=2", arg1=3)
30
+ ['--arg1 3']
31
+ """
32
+ keys = [key.lstrip("-") for key in re.findall(r"--\w+", cmd)]
33
+ arguments = {}
34
+ for key in keys:
35
+ if key in kwargs:
36
+ arguments[key] = kwargs[key]
37
+ return [f"--{k} {v}" for k, v in arguments.items()]
38
+
39
+ @staticmethod
40
+ def _resolve_api_arguments(parameter_spec: "V1ParameterizationSpec", **kwargs: dict) -> str:
41
+ return " ".join(AIHubApi._parse_and_update_args(parameter_spec.command, **kwargs))
42
+
43
+ def deploy_api(
44
+ self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], **kwargs: dict
45
+ ) -> V1Deployment:
46
+ template = self._client.deployment_templates_service_get_deployment_template(template_id)
47
+ name = name or template.name
48
+ template.spec_v2.endpoint.id = None
49
+ command = self._resolve_api_arguments(template.parameter_spec, **kwargs)
50
+ template.spec_v2.job.command = command
51
+ return self._client.jobs_service_create_deployment(
52
+ project_id=project_id,
53
+ body=CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(
54
+ autoscaling=template.spec_v2.autoscaling,
55
+ cluster_id=cluster_id,
56
+ endpoint=template.spec_v2.endpoint,
57
+ name=name,
58
+ replicas=0,
59
+ spec=template.spec_v2.job,
60
+ ),
61
+ )
@@ -5,7 +5,7 @@ import time
5
5
  import warnings
6
6
  import zipfile
7
7
  from threading import Event, Thread
8
- from typing import Any, Dict, Mapping, Optional, Tuple
8
+ from typing import Any, Dict, Mapping, Optional, Tuple, Union
9
9
 
10
10
  import backoff
11
11
  import requests
@@ -18,6 +18,7 @@ from lightning_sdk.api.utils import (
18
18
  _DummyBody,
19
19
  _DummyResponse,
20
20
  _FileUploader,
21
+ _machine_to_compute_name,
21
22
  _sanitize_studio_remote_path,
22
23
  )
23
24
  from lightning_sdk.api.utils import (
@@ -146,9 +147,11 @@ class StudioApi:
146
147
  """Retries checking the sync_in_progress value of the code status when there's an AttributeError."""
147
148
  return self.get_studio_status(studio_id, teamspace_id).in_use.sync_in_progress
148
149
 
149
- def start_studio(self, studio_id: str, teamspace_id: str, machine: Machine, interruptible: False) -> None:
150
+ def start_studio(
151
+ self, studio_id: str, teamspace_id: str, machine: Union[Machine, str], interruptible: False
152
+ ) -> None:
150
153
  """Start an existing Studio."""
151
- if machine == Machine.CPU_SMALL:
154
+ if _machine_to_compute_name(machine) == _machine_to_compute_name(Machine.CPU_SMALL):
152
155
  warnings.warn(
153
156
  f"{Machine.CPU_SMALL} is not a valid machine for starting a Studio. "
154
157
  "It is reserved for running jobs only. "
@@ -158,7 +161,7 @@ class StudioApi:
158
161
 
159
162
  self._client.cloud_space_service_start_cloud_space_instance(
160
163
  IdStartBody(
161
- compute_config=V1UserRequestedComputeConfig(name=_MACHINE_TO_COMPUTE_NAME[machine], spot=interruptible)
164
+ compute_config=V1UserRequestedComputeConfig(name=_machine_to_compute_name(machine), spot=interruptible)
162
165
  ),
163
166
  teamspace_id,
164
167
  studio_id,
@@ -204,9 +207,11 @@ class StudioApi:
204
207
  def _get_studio_instance_status_from_object(self, studio: V1CloudSpace) -> Optional[str]:
205
208
  return getattr(getattr(studio.code_status, "in_use", None), "phase", None)
206
209
 
207
- def _request_switch(self, studio_id: str, teamspace_id: str, machine: Machine, interruptible: bool) -> None:
210
+ def _request_switch(
211
+ self, studio_id: str, teamspace_id: str, machine: Union[Machine, str], interruptible: bool
212
+ ) -> None:
208
213
  """Switches given Studio to a new machine type."""
209
- if machine == Machine.CPU_SMALL:
214
+ if _machine_to_compute_name(machine) == _machine_to_compute_name(Machine.CPU_SMALL):
210
215
  warnings.warn(
211
216
  f"{Machine.CPU_SMALL} is not a valid machine for switching a Studio. "
212
217
  "It is reserved for running jobs only. "
@@ -214,7 +219,7 @@ class StudioApi:
214
219
  )
215
220
  machine = Machine.CPU
216
221
 
217
- compute_name = _MACHINE_TO_COMPUTE_NAME[machine]
222
+ compute_name = _machine_to_compute_name(machine)
218
223
  # TODO: UI sends disk size here, maybe we need to also?
219
224
  body = IdCodeconfigBody(compute_config=V1UserRequestedComputeConfig(name=compute_name, spot=interruptible))
220
225
  self._client.cloud_space_service_update_cloud_space_instance_config(
@@ -223,7 +228,9 @@ class StudioApi:
223
228
  body=body,
224
229
  )
225
230
 
226
- def switch_studio_machine(self, studio_id: str, teamspace_id: str, machine: Machine, interruptible: bool) -> None:
231
+ def switch_studio_machine(
232
+ self, studio_id: str, teamspace_id: str, machine: Union[Machine, str], interruptible: bool
233
+ ) -> None:
227
234
  """Switches given Studio to a new machine type."""
228
235
  self._request_switch(
229
236
  studio_id=studio_id, teamspace_id=teamspace_id, machine=machine, interruptible=interruptible
@@ -4,7 +4,7 @@ import os
4
4
  from concurrent.futures import ThreadPoolExecutor
5
5
  from functools import partial
6
6
  from pathlib import Path
7
- from typing import Any, Dict, List, Optional, Tuple
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import backoff
10
10
  import requests
@@ -340,6 +340,13 @@ _MACHINE_TO_COMPUTE_NAME: Dict[Machine, str] = {
340
340
  Machine.H200_X_8: "p5e.48xlarge",
341
341
  }
342
342
 
343
+
344
+ def _machine_to_compute_name(machine: Union[Machine, str]) -> str:
345
+ if isinstance(machine, Machine):
346
+ return _MACHINE_TO_COMPUTE_NAME[machine]
347
+ return machine
348
+
349
+
343
350
  _COMPUTE_NAME_TO_MACHINE: Dict[str, Machine] = {v: k for k, v in _MACHINE_TO_COMPUTE_NAME.items()}
344
351
 
345
352
  _DEFAULT_CLOUD_URL = "https://lightning.ai:443"
@@ -45,6 +45,7 @@ class CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(o
45
45
  'cluster_id': 'str',
46
46
  'endpoint': 'V1Endpoint',
47
47
  'name': 'str',
48
+ 'parameter_spec': 'V1ParameterizationSpec',
48
49
  'replicas': 'int',
49
50
  'spec': 'V1JobSpec',
50
51
  'strategy': 'V1DeploymentStrategy'
@@ -55,17 +56,19 @@ class CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(o
55
56
  'cluster_id': 'clusterId',
56
57
  'endpoint': 'endpoint',
57
58
  'name': 'name',
59
+ 'parameter_spec': 'parameterSpec',
58
60
  'replicas': 'replicas',
59
61
  'spec': 'spec',
60
62
  'strategy': 'strategy'
61
63
  }
62
64
 
63
- def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, cluster_id: 'str' =None, endpoint: 'V1Endpoint' =None, name: 'str' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, strategy: 'V1DeploymentStrategy' =None): # noqa: E501
65
+ def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, cluster_id: 'str' =None, endpoint: 'V1Endpoint' =None, name: 'str' =None, parameter_spec: 'V1ParameterizationSpec' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, strategy: 'V1DeploymentStrategy' =None): # noqa: E501
64
66
  """CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs - a model defined in Swagger""" # noqa: E501
65
67
  self._autoscaling = None
66
68
  self._cluster_id = None
67
69
  self._endpoint = None
68
70
  self._name = None
71
+ self._parameter_spec = None
69
72
  self._replicas = None
70
73
  self._spec = None
71
74
  self._strategy = None
@@ -78,6 +81,8 @@ class CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(o
78
81
  self.endpoint = endpoint
79
82
  if name is not None:
80
83
  self.name = name
84
+ if parameter_spec is not None:
85
+ self.parameter_spec = parameter_spec
81
86
  if replicas is not None:
82
87
  self.replicas = replicas
83
88
  if spec is not None:
@@ -169,6 +174,27 @@ class CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(o
169
174
 
170
175
  self._name = name
171
176
 
177
+ @property
178
+ def parameter_spec(self) -> 'V1ParameterizationSpec':
179
+ """Gets the parameter_spec of this CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs. # noqa: E501
180
+
181
+
182
+ :return: The parameter_spec of this CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs. # noqa: E501
183
+ :rtype: V1ParameterizationSpec
184
+ """
185
+ return self._parameter_spec
186
+
187
+ @parameter_spec.setter
188
+ def parameter_spec(self, parameter_spec: 'V1ParameterizationSpec'):
189
+ """Sets the parameter_spec of this CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs.
190
+
191
+
192
+ :param parameter_spec: The parameter_spec of this CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs. # noqa: E501
193
+ :type: V1ParameterizationSpec
194
+ """
195
+
196
+ self._parameter_spec = parameter_spec
197
+
172
198
  @property
173
199
  def replicas(self) -> 'int':
174
200
  """Gets the replicas of this CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs. # noqa: E501
@@ -47,6 +47,7 @@ class DeploymentsIdBody(object):
47
47
  'endpoint': 'V1Endpoint',
48
48
  'is_published': 'bool',
49
49
  'name': 'str',
50
+ 'parameter_spec': 'V1ParameterizationSpec',
50
51
  'release_id': 'str',
51
52
  'replicas': 'int',
52
53
  'spec': 'V1JobSpec',
@@ -63,6 +64,7 @@ class DeploymentsIdBody(object):
63
64
  'endpoint': 'endpoint',
64
65
  'is_published': 'isPublished',
65
66
  'name': 'name',
67
+ 'parameter_spec': 'parameterSpec',
66
68
  'release_id': 'releaseId',
67
69
  'replicas': 'replicas',
68
70
  'spec': 'spec',
@@ -72,7 +74,7 @@ class DeploymentsIdBody(object):
72
74
  'user_id': 'userId'
73
75
  }
74
76
 
75
- def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, created_at: 'datetime' =None, desired_state: 'V1DeploymentState' =None, endpoint: 'V1Endpoint' =None, is_published: 'bool' =None, name: 'str' =None, release_id: 'str' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, status: 'V1DeploymentStatus' =None, strategy: 'V1DeploymentStrategy' =None, updated_at: 'datetime' =None, user_id: 'str' =None): # noqa: E501
77
+ def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, created_at: 'datetime' =None, desired_state: 'V1DeploymentState' =None, endpoint: 'V1Endpoint' =None, is_published: 'bool' =None, name: 'str' =None, parameter_spec: 'V1ParameterizationSpec' =None, release_id: 'str' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, status: 'V1DeploymentStatus' =None, strategy: 'V1DeploymentStrategy' =None, updated_at: 'datetime' =None, user_id: 'str' =None): # noqa: E501
76
78
  """DeploymentsIdBody - a model defined in Swagger""" # noqa: E501
77
79
  self._autoscaling = None
78
80
  self._created_at = None
@@ -80,6 +82,7 @@ class DeploymentsIdBody(object):
80
82
  self._endpoint = None
81
83
  self._is_published = None
82
84
  self._name = None
85
+ self._parameter_spec = None
83
86
  self._release_id = None
84
87
  self._replicas = None
85
88
  self._spec = None
@@ -100,6 +103,8 @@ class DeploymentsIdBody(object):
100
103
  self.is_published = is_published
101
104
  if name is not None:
102
105
  self.name = name
106
+ if parameter_spec is not None:
107
+ self.parameter_spec = parameter_spec
103
108
  if release_id is not None:
104
109
  self.release_id = release_id
105
110
  if replicas is not None:
@@ -241,6 +246,27 @@ class DeploymentsIdBody(object):
241
246
 
242
247
  self._name = name
243
248
 
249
+ @property
250
+ def parameter_spec(self) -> 'V1ParameterizationSpec':
251
+ """Gets the parameter_spec of this DeploymentsIdBody. # noqa: E501
252
+
253
+
254
+ :return: The parameter_spec of this DeploymentsIdBody. # noqa: E501
255
+ :rtype: V1ParameterizationSpec
256
+ """
257
+ return self._parameter_spec
258
+
259
+ @parameter_spec.setter
260
+ def parameter_spec(self, parameter_spec: 'V1ParameterizationSpec'):
261
+ """Sets the parameter_spec of this DeploymentsIdBody.
262
+
263
+
264
+ :param parameter_spec: The parameter_spec of this DeploymentsIdBody. # noqa: E501
265
+ :type: V1ParameterizationSpec
266
+ """
267
+
268
+ self._parameter_spec = parameter_spec
269
+
244
270
  @property
245
271
  def release_id(self) -> 'str':
246
272
  """Gets the release_id of this DeploymentsIdBody. # noqa: E501
@@ -48,6 +48,7 @@ class V1Deployment(object):
48
48
  'id': 'str',
49
49
  'is_published': 'bool',
50
50
  'name': 'str',
51
+ 'parameter_spec': 'V1ParameterizationSpec',
51
52
  'project_id': 'str',
52
53
  'release_id': 'str',
53
54
  'replicas': 'int',
@@ -66,6 +67,7 @@ class V1Deployment(object):
66
67
  'id': 'id',
67
68
  'is_published': 'isPublished',
68
69
  'name': 'name',
70
+ 'parameter_spec': 'parameterSpec',
69
71
  'project_id': 'projectId',
70
72
  'release_id': 'releaseId',
71
73
  'replicas': 'replicas',
@@ -76,7 +78,7 @@ class V1Deployment(object):
76
78
  'user_id': 'userId'
77
79
  }
78
80
 
79
- def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, created_at: 'datetime' =None, desired_state: 'V1DeploymentState' =None, endpoint: 'V1Endpoint' =None, id: 'str' =None, is_published: 'bool' =None, name: 'str' =None, project_id: 'str' =None, release_id: 'str' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, status: 'V1DeploymentStatus' =None, strategy: 'V1DeploymentStrategy' =None, updated_at: 'datetime' =None, user_id: 'str' =None): # noqa: E501
81
+ def __init__(self, autoscaling: 'V1AutoscalingSpec' =None, created_at: 'datetime' =None, desired_state: 'V1DeploymentState' =None, endpoint: 'V1Endpoint' =None, id: 'str' =None, is_published: 'bool' =None, name: 'str' =None, parameter_spec: 'V1ParameterizationSpec' =None, project_id: 'str' =None, release_id: 'str' =None, replicas: 'int' =None, spec: 'V1JobSpec' =None, status: 'V1DeploymentStatus' =None, strategy: 'V1DeploymentStrategy' =None, updated_at: 'datetime' =None, user_id: 'str' =None): # noqa: E501
80
82
  """V1Deployment - a model defined in Swagger""" # noqa: E501
81
83
  self._autoscaling = None
82
84
  self._created_at = None
@@ -85,6 +87,7 @@ class V1Deployment(object):
85
87
  self._id = None
86
88
  self._is_published = None
87
89
  self._name = None
90
+ self._parameter_spec = None
88
91
  self._project_id = None
89
92
  self._release_id = None
90
93
  self._replicas = None
@@ -108,6 +111,8 @@ class V1Deployment(object):
108
111
  self.is_published = is_published
109
112
  if name is not None:
110
113
  self.name = name
114
+ if parameter_spec is not None:
115
+ self.parameter_spec = parameter_spec
111
116
  if project_id is not None:
112
117
  self.project_id = project_id
113
118
  if release_id is not None:
@@ -272,6 +277,27 @@ class V1Deployment(object):
272
277
 
273
278
  self._name = name
274
279
 
280
+ @property
281
+ def parameter_spec(self) -> 'V1ParameterizationSpec':
282
+ """Gets the parameter_spec of this V1Deployment. # noqa: E501
283
+
284
+
285
+ :return: The parameter_spec of this V1Deployment. # noqa: E501
286
+ :rtype: V1ParameterizationSpec
287
+ """
288
+ return self._parameter_spec
289
+
290
+ @parameter_spec.setter
291
+ def parameter_spec(self, parameter_spec: 'V1ParameterizationSpec'):
292
+ """Sets the parameter_spec of this V1Deployment.
293
+
294
+
295
+ :param parameter_spec: The parameter_spec of this V1Deployment. # noqa: E501
296
+ :type: V1ParameterizationSpec
297
+ """
298
+
299
+ self._parameter_spec = parameter_spec
300
+
275
301
  @property
276
302
  def project_id(self) -> 'str':
277
303
  """Gets the project_id of this V1Deployment. # noqa: E501
lightning_sdk/studio.py CHANGED
@@ -3,6 +3,7 @@ import warnings
3
3
  from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple, Union
4
4
 
5
5
  from lightning_sdk.api.studio_api import StudioApi
6
+ from lightning_sdk.api.utils import _machine_to_compute_name
6
7
  from lightning_sdk.constants import _LIGHTNING_DEBUG
7
8
  from lightning_sdk.machine import Machine
8
9
  from lightning_sdk.organization import Organization
@@ -145,14 +146,14 @@ class Studio:
145
146
  """Returns the cluster the Studio is running on."""
146
147
  return self._studio.cluster_id
147
148
 
148
- def start(self, machine: Machine = Machine.CPU, interruptible: bool = False) -> None:
149
+ def start(self, machine: Union[Machine, str] = Machine.CPU, interruptible: bool = False) -> None:
149
150
  """Starts a Studio on the specified machine type (default: CPU-4)."""
150
151
  status = self.status
151
152
  if status == Status.Running:
152
- curr_machine = self.machine
153
- if curr_machine != machine:
153
+ curr_machine = _machine_to_compute_name(self.machine) if self.machine is not None else None
154
+ if curr_machine != _machine_to_compute_name(machine):
154
155
  raise RuntimeError(
155
- f"Requested to start studio on {machine}, but studio is already running on {curr_machine}."
156
+ f"Requested to start studio on {machine}, but studio is already running on {self.machine}."
156
157
  " Consider switching instead!"
157
158
  )
158
159
  _logger.info(f"Studio {self.name} is already running")
@@ -180,7 +181,7 @@ class Studio:
180
181
  kwargs = self._studio_api.duplicate_studio(self._studio.id, self._teamspace.id, self._teamspace.id)
181
182
  return Studio(**kwargs)
182
183
 
183
- def switch_machine(self, machine: Machine, interruptible: bool = False) -> None:
184
+ def switch_machine(self, machine: Union[Machine, str], interruptible: bool = False) -> None:
184
185
  """Switches machine to the provided machine type/.
185
186
 
186
187
  Args:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lightning_sdk
3
- Version: 0.1.32
3
+ Version: 0.1.33
4
4
  Summary: SDK to develop using Lightning AI Studios
5
5
  Author-email: Lightning-AI <justus@lightning.ai>
6
6
  License: MIT License
@@ -1,7 +1,7 @@
1
1
  docs/source/conf.py,sha256=r8yX20eC-4mHhMTd0SbQb5TlSWHhO6wnJ0VJ_FBFpag,13249
2
- lightning_sdk/__init__.py,sha256=R_77kIsdEkN0Lxnf5ETmShr62qDXfjG_gi2HIeT9LmA,925
2
+ lightning_sdk/__init__.py,sha256=rB9uPGPKZ4KsCOSgqGSglzmuOvrxBtIp_o_kdXHqlZM,925
3
3
  lightning_sdk/agents.py,sha256=ly6Ma1j0ZgGPFyvPvMN28JWiB9dATIstFa5XM8pMi6I,1577
4
- lightning_sdk/ai_hub.py,sha256=wVgLzqDzv0d3yWyoOEvE8epR97pWu-MoAYqysiO2evs,785
4
+ lightning_sdk/ai_hub.py,sha256=yjCkfmJRKI9zpIn3qUZ7J5C1lccTt6UnKD6iv0xXsX0,3966
5
5
  lightning_sdk/constants.py,sha256=ztl1PTUBULnqTf3DyKUSJaV_O20hNtUYT6XvAYIrmIk,749
6
6
  lightning_sdk/helpers.py,sha256=RnQwUquc_YPotjh6YXOoJvZs8krX_QFhd7kGv4U_spQ,1844
7
7
  lightning_sdk/machine.py,sha256=VdFXStR6ilYBEYuxgGWzcAw2TtW-nEQVsh6hz-2aaEw,750
@@ -9,19 +9,19 @@ lightning_sdk/organization.py,sha256=WCfzdgjtvY1_A07DnxOpp74V2JR2gQwtXbIEcFDnoVU
9
9
  lightning_sdk/owner.py,sha256=t5svD2it4C9pbSpVuG9WJL46CYi37JXNziwnXxhiU5U,1361
10
10
  lightning_sdk/plugin.py,sha256=2OgP8YtcLo2K1T0pypLGj3dkIaytTw6N11CHHQmRupY,13591
11
11
  lightning_sdk/status.py,sha256=kLDhN4-zdsGuZM577JMl1BbUIoF61bUOadW89ZAATFA,219
12
- lightning_sdk/studio.py,sha256=ivwRwPSJtve2iLOwjatXKCtNNt-bAVkhEpB-GRqsQz4,16588
12
+ lightning_sdk/studio.py,sha256=wvZrk07ZuG6Bf6c4XK8gMBumyiz_90JUScczdiWq1rU,16763
13
13
  lightning_sdk/teamspace.py,sha256=T3nxbZG1HtKu40F8gtpapqn-GVqSL6nBv7d32vmB_DY,10567
14
14
  lightning_sdk/user.py,sha256=vdn8pZqkAZO0-LoRsBdg0TckRKtd_H3QF4gpiZcl4iY,1130
15
15
  lightning_sdk/api/__init__.py,sha256=Qn2VVRvir_gO7w4yxGLkZY-R3T7kdiTPKgQ57BhIA9k,413
16
16
  lightning_sdk/api/agents_api.py,sha256=G47TbFo9kYqnBMqdw2RW-lfS1VAUBSXDmzs6fpIEMUs,4059
17
- lightning_sdk/api/ai_hub_api.py,sha256=DKVc80OY6Ehx1KoOEkHSnDZJxs6lh28C6I-4erJ2zew,572
17
+ lightning_sdk/api/ai_hub_api.py,sha256=Gf4-0FXQROpIXiuQipp-pjLFGNuU-0hft7V_rpxg3Hc,2479
18
18
  lightning_sdk/api/deployment_api.py,sha256=9HhOHz7ElmjgX88YkPUtETPgz_jMFkK1nTO18-jUniI,21505
19
19
  lightning_sdk/api/job_api.py,sha256=pdQ3R5x04fyC49UqggWCIuPDLCmLYrnq791Ij-3vMr0,8457
20
20
  lightning_sdk/api/org_api.py,sha256=Ze3z_ATVrukobujV5YdC42DKj45Vuwl7X52q_Vr-o3U,803
21
- lightning_sdk/api/studio_api.py,sha256=es_ZpFwSDotYgWvJ4MunY7qEpSqQLcA7amTK9fG3tII,25924
21
+ lightning_sdk/api/studio_api.py,sha256=ypgTAUJhwfdsEOJHfog4JsgZiLp9ZQA0KggSjLiI0ZA,26143
22
22
  lightning_sdk/api/teamspace_api.py,sha256=o-GBR3KLo298kDxO0myx-qlcCzSZnbR2OhZ73tt_5B8,9871
23
23
  lightning_sdk/api/user_api.py,sha256=sL7RIjjtmZmvCZWx7BBZslhj1BeNh4Idn-RVcdmf7M0,2598
24
- lightning_sdk/api/utils.py,sha256=Pnze_iy14cScQdOs8oELLXP-9WncpyigVEaohiT7uLM,20380
24
+ lightning_sdk/api/utils.py,sha256=eHxq0dU_e9Pq5XjPlp7sNJ5_N3d44EtHKpbEiPPMkOQ,20562
25
25
  lightning_sdk/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  lightning_sdk/cli/download.py,sha256=b6dpUMLhro6p0y6CleCHxggQnZlZMxCq0bjuf8B6_eA,6191
27
27
  lightning_sdk/cli/entrypoint.py,sha256=DD9rbmNqH6uE0IKMzfOenVymyrvfTMYi_zDPHr4Pk5E,1258
@@ -110,10 +110,10 @@ lightning_sdk/lightning_cloud/openapi/models/command_argument_command_argument_t
110
110
  lightning_sdk/lightning_cloud/openapi/models/conversations_id_body.py,sha256=blW66WBRE-qMCkzLCOrx4EqmM1vI42O3MvP9rIlx_tE,3651
111
111
  lightning_sdk/lightning_cloud/openapi/models/create.py,sha256=xrtxd9B0VkqfdszT5GXOmfD5PLmvPrbxTd2Ktjq55eE,14373
112
112
  lightning_sdk/lightning_cloud/openapi/models/create_checkout_session_request_wallet_type.py,sha256=6mUpEXvoZxyrnOvfT9PfF34E3OGvRMzMfLFndUe-jKU,3217
113
- lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py,sha256=wGwtlmnWMhWMyJ2opvaV8llc7_tKPprpNjSOQTJeyVU,9970
113
+ lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py,sha256=tbgFc6advfcyZCmeTiGzPGdASHbOzV9jxUa3DPgaapA,11124
114
114
  lightning_sdk/lightning_cloud/openapi/models/data_connection_mount_data_connection_mount_copy_status.py,sha256=ytC9VwBzQMvpOSZJ78uekfrMmwGRaTOclX5XC3a-dMs,3343
115
115
  lightning_sdk/lightning_cloud/openapi/models/datasets_id_body.py,sha256=OcNzTWEbLGF36eY-yImNCmTwY4ZMX2P8yPuHAvc6KSk,5533
116
- lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py,sha256=OQJ_v3nlNxgPEs6A6JE4uSMv1qZsZqJfRGtSjo16vRM,12741
116
+ lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py,sha256=CzAP4ulw4Q7nBFyLU3xHckqDLDKY7TRAB5ZaM64mV8o,13675
117
117
  lightning_sdk/lightning_cloud/openapi/models/deploymenttemplates_id_body.py,sha256=2TNCtt0YINExIsa6GYNNps5ton8QN-PPWSo6juG-oRo,14223
118
118
  lightning_sdk/lightning_cloud/openapi/models/endpoints_id_body.py,sha256=-Nv94Q8XoAT3tfYVhcjutpihfoNS2b2STbTHJr1gomo,12309
119
119
  lightning_sdk/lightning_cloud/openapi/models/experiment_name_variant_name_body.py,sha256=f-pBOnWAd03BreivwAA0D1P1wEBGq2jaC1nK-GtGxVA,3910
@@ -384,7 +384,7 @@ lightning_sdk/lightning_cloud/openapi/models/v1_delete_studio_job_response.py,sh
384
384
  lightning_sdk/lightning_cloud/openapi/models/v1_delete_user_slurm_job_response.py,sha256=kmy5RJKy63bRzBcWdMk00-vT3m2lVs05cic7jfU1JyQ,3058
385
385
  lightning_sdk/lightning_cloud/openapi/models/v1_dependency_cache_state.py,sha256=VL2MY_XvHWAd8toG-myqZlhuoEwjAIX1cemNSXxqk-k,3210
386
386
  lightning_sdk/lightning_cloud/openapi/models/v1_dependency_file_info.py,sha256=Fyh7eRcxFJ8Y4LHA3uJe_D5F9RhOxT4ChllHph_ryfQ,4571
387
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py,sha256=7938YMxE3gyAJ6TyQogPiLu5PCwGKRPPJF74KG6n7lM,13732
387
+ lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py,sha256=gcZLLsjCX21_IanX_pmHC2TkDhFoH8feGKAd9sXhbwo,14646
388
388
  lightning_sdk/lightning_cloud/openapi/models/v1_deployment_event.py,sha256=XDpq8DM8sN2UKCIlBWGEk-vzjGQGNoc0KjyRcCzDqZ4,9994
389
389
  lightning_sdk/lightning_cloud/openapi/models/v1_deployment_event_type.py,sha256=uuogrGRymVVSz-OcYoJ7eZyFj3POXY22EAQJkXpTVm4,3197
390
390
  lightning_sdk/lightning_cloud/openapi/models/v1_deployment_metrics.py,sha256=xHhh0-aPxmyyH3g1Ks2h2oFnEy-CtZjDWH68jjetPEQ,3611
@@ -808,9 +808,9 @@ lightning_sdk/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSu
808
808
  lightning_sdk/utils/dynamic.py,sha256=glUTO1JC9APtQ6Gr9SO02a3zr56-sPAXM5C3NrTpgyQ,1959
809
809
  lightning_sdk/utils/enum.py,sha256=h2JRzqoBcSlUdanFHmkj_j5DleBHAu1esQYUsdNI-hU,4106
810
810
  lightning_sdk/utils/resolve.py,sha256=gU3MSko9Y7rE4jcnVwstNBaW83OFnSgvM-N44Ibrc_A,5148
811
- lightning_sdk-0.1.32.dist-info/LICENSE,sha256=uFIuZwj5z-4TeF2UuacPZ1o17HkvKObT8fY50qN84sg,1064
812
- lightning_sdk-0.1.32.dist-info/METADATA,sha256=vxqkGNSCH82QlLfXB4ne_BdEdQWjqRpAZDlD02jAb3M,3941
813
- lightning_sdk-0.1.32.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
814
- lightning_sdk-0.1.32.dist-info/entry_points.txt,sha256=msB9PJWIJ784dX-OP8by51d4IbKYH3Fj1vCuA9oXjHY,68
815
- lightning_sdk-0.1.32.dist-info/top_level.txt,sha256=ps8doKILFXmN7F1mHncShmnQoTxKBRPIcchC8TpoBw4,19
816
- lightning_sdk-0.1.32.dist-info/RECORD,,
811
+ lightning_sdk-0.1.33.dist-info/LICENSE,sha256=uFIuZwj5z-4TeF2UuacPZ1o17HkvKObT8fY50qN84sg,1064
812
+ lightning_sdk-0.1.33.dist-info/METADATA,sha256=Q20ldoup46qvs8rCqTg3mCuO0e4si6qBOiXyWMwrbvs,3941
813
+ lightning_sdk-0.1.33.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
814
+ lightning_sdk-0.1.33.dist-info/entry_points.txt,sha256=msB9PJWIJ784dX-OP8by51d4IbKYH3Fj1vCuA9oXjHY,68
815
+ lightning_sdk-0.1.33.dist-info/top_level.txt,sha256=ps8doKILFXmN7F1mHncShmnQoTxKBRPIcchC8TpoBw4,19
816
+ lightning_sdk-0.1.33.dist-info/RECORD,,